{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.3"
    },
    "toc": {
      "nav_menu": {},
      "number_sections": true,
      "sideBar": true,
      "skip_h1_title": false,
      "title_cell": "Table of Contents",
      "title_sidebar": "Contents",
      "toc_cell": false,
      "toc_position": {},
      "toc_section_display": true,
      "toc_window_display": false
    },
    "colab": {
      "name": "gnn-basic-graph-spectral-1.ipynb",
      "provenance": []
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "c77aa4e9be0c437a9c89ee73f59f5b5d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_848739c75e294185b10e6c47177798dc",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_5d58a9de0d9343749ac7e0fd8d4ad753",
              "IPY_MODEL_d00366ded630474c9666707163bcaf3e"
            ]
          }
        },
        "848739c75e294185b10e6c47177798dc": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "5d58a9de0d9343749ac7e0fd8d4ad753": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_a0ce5db699f2444a89250ea89dfee0f9",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_258ebc3f21aa4a6fafe484a834c713cc"
          }
        },
        "d00366ded630474c9666707163bcaf3e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_f3375cd64b0d43da9a5642984bb850b2",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 9920512/? [00:06&lt;00:00, 1492484.52it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_b868af1d05784b5ea4057d06260f0a08"
          }
        },
        "a0ce5db699f2444a89250ea89dfee0f9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "258ebc3f21aa4a6fafe484a834c713cc": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "f3375cd64b0d43da9a5642984bb850b2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "b868af1d05784b5ea4057d06260f0a08": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "bf9c85751f4840c8b9fbc23fc1f2eb16": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_3b47b0fa093d4df6ab0d1c4b61cb55f0",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_a49b357bc3c945b798d29f4b4654999d",
              "IPY_MODEL_f6ebf182b573470bb3cec0f8cb45eaea"
            ]
          }
        },
        "3b47b0fa093d4df6ab0d1c4b61cb55f0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "a49b357bc3c945b798d29f4b4654999d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_87702662ec224840a28b8884d7f851ba",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_0a4bd7085c3343d4adc715b9bb4fe1d6"
          }
        },
        "f6ebf182b573470bb3cec0f8cb45eaea": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_9fc8039b112b4085a2b44d30d3404b97",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 32768/? [00:03&lt;00:00, 9448.71it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_00efd74f2d3c406585b5affbd10ef788"
          }
        },
        "87702662ec224840a28b8884d7f851ba": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "0a4bd7085c3343d4adc715b9bb4fe1d6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "9fc8039b112b4085a2b44d30d3404b97": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "00efd74f2d3c406585b5affbd10ef788": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "6f99f16e7f104aaf922e2b33e0f5f209": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_0675d1e2cfc04213ab4b2840d0e50f0b",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_0744ccfd292844dcb779c9080690fd94",
              "IPY_MODEL_679caad6aaaf4bf398ee7102f898c1de"
            ]
          }
        },
        "0675d1e2cfc04213ab4b2840d0e50f0b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "0744ccfd292844dcb779c9080690fd94": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_31725af822334fa289b3897b7c01f9be",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_3b6716ae57b8444e9da5baf0145eb9d7"
          }
        },
        "679caad6aaaf4bf398ee7102f898c1de": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_0446988b9b9a438788315c42594e5c7e",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 1654784/? [00:02&lt;00:00, 669454.68it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_7b2f20a9bfa34dee9931636c996fa9b6"
          }
        },
        "31725af822334fa289b3897b7c01f9be": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "3b6716ae57b8444e9da5baf0145eb9d7": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "0446988b9b9a438788315c42594e5c7e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "7b2f20a9bfa34dee9931636c996fa9b6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "6ce7eaa930bc494ca2bca68ad0d5dd83": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_1987d5ae6ea641eda46934c307c88b0b",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_64a9fc9418084713a071b23d2a80e703",
              "IPY_MODEL_0416f10f14a74815a5b6a4f49eca5e9c"
            ]
          }
        },
        "1987d5ae6ea641eda46934c307c88b0b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "64a9fc9418084713a071b23d2a80e703": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_42ca724707a848378c7e8dee107eb10e",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_1a22a1e62b9e4d1996d2c8570707be7f"
          }
        },
        "0416f10f14a74815a5b6a4f49eca5e9c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_7585afec669a4d1fae23cb693e1024ad",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 8192/? [00:00&lt;00:00, 13048.05it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_5aa2ee1a669943128faf8f8338c7d2e4"
          }
        },
        "42ca724707a848378c7e8dee107eb10e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "1a22a1e62b9e4d1996d2c8570707be7f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "7585afec669a4d1fae23cb693e1024ad": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "5aa2ee1a669943128faf8f8338c7d2e4": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xp2NZRkVg0cL",
        "colab_type": "text"
      },
      "source": [
        "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
        "- Author: Sebastian Raschka\n",
        "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QfRwmoX9h3U1",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 86
        },
        "outputId": "48cbb5cb-55df-45ba-a88e-ad1dd84f5abb"
      },
      "source": [
        "!pip install -q IPython\n",
        "!pip install -q ipykernel\n",
        "!pip install -q watermark\n",
        "!pip install -q matplotlib\n",
        "!pip install -q sklearn\n",
        "!pip install -q pandas\n",
        "!pip install -q pydot\n",
        "!pip install -q hiddenlayer\n",
        "!pip install -q graphviz\n",
        "!pip install -q tensorwatch"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\u001b[K     |████████████████████████████████| 194kB 2.6MB/s \n",
            "\u001b[K     |████████████████████████████████| 143kB 8.1MB/s \n",
            "\u001b[?25h  Building wheel for tensorwatch (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for pydotz (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pPudzCGEg0cM",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 121
        },
        "outputId": "428824f6-8d85-495c-e13f-d226ad539c01"
      },
      "source": [
        "%load_ext watermark\n",
        "%watermark -a 'Sebastian Raschka' -v -p torch"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Sebastian Raschka \n",
            "\n",
            "CPython 3.6.9\n",
            "IPython 5.5.0\n",
            "\n",
            "torch 1.5.1+cu101\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5u0FWpUhg0cR",
        "colab_type": "text"
      },
      "source": [
        "- Runs on CPU or GPU (if available)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_VfiGTiPg0cS",
        "colab_type": "text"
      },
      "source": [
        "# Basic Graph Neural Network with Spectral Graph Convolution on MNIST"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xuXNSSaWg0cT",
        "colab_type": "text"
      },
      "source": [
        "Implementing a very basic graph neural network (GNN) using a spectral graph convolution. \n",
        "\n",
        "Here, the 28x28 image of a digit in MNIST represents the graph, where each pixel (i.e., cell in the grid) represents a particular node. The feature of that node is simply the pixel intensity in range [0, 1]. \n",
        "\n",
        "Here, the adjacency matrix of the pixels is basically just determined by their neighborhood pixels. Using a Gaussian filter, we connect pixels based on their Euclidean distance in the grid.\n",
        "\n",
        "In the related notebook, [./gnn-basic-1.ipynb](./gnn-basic-1.ipynb), we used this adjacency matrix $A$ to compute the output of a layer as \n",
        "\n",
        "$$X^{(l+1)}=A X^{(l)} W^{(l)}.$$\n",
        "\n",
        "Here, $A$ is the $N \\times N$ adjacency matrix, and $X$ is the $N \\times C$ feature matrix (a  2D coordinate array, where $N$ is the total number of pixels -- $28 \\times 28 = 784$ in MNIST). $W$ is the weight matrix of shape $N \\times P$, where $P$ would represent the number of classes if we have only a single hidden layer.\n",
        "\n",
        "In this notebook, we modify this code using spectral graph convolution, i.e.,\n",
        "\n",
        "$$X^{(l+1)}=V\\left(V^{T} X^{(l)} \\odot V^{T} W_{\\text {spectral }}^{(l)}\\right).$$\n",
        "\n",
        "Where $V$ are the eigenvectors of the graph Laplacian $L$, which we can compute from the adjacency matrix $A$. Here, $W_{\\text {spectral }}$ represents the trainable weights (filters).\n",
        "\n",
        "- Inspired by and based on Boris Knyazev's tutorial at https://towardsdatascience.com/tutorial-on-graph-neural-networks-for-computer-vision-and-beyond-part-2-be6d71d70f49."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GzdrI6hQg0cT",
        "colab_type": "text"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "v4gP_GtKg0cU",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import time\n",
        "import numpy as np\n",
        "from scipy.spatial.distance import cdist\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torchvision import datasets\n",
        "from torchvision import transforms\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.utils.data.dataset import Subset\n",
        "\n",
        "\n",
        "if torch.cuda.is_available():\n",
        "    torch.backends.cudnn.deterministic = True"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NIGxmJTUg0cY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "%matplotlib inline\n",
        "import matplotlib.pyplot as plt"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f3gguL__g0ca",
        "colab_type": "text"
      },
      "source": [
        "## Settings and Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lAieeZ2Xg0cb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "##########################\n",
        "### SETTINGS\n",
        "##########################\n",
        "\n",
        "# Device\n",
        "DEVICE = torch.device(\"cuda:3\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "# Hyperparameters\n",
        "RANDOM_SEED = 1\n",
        "LEARNING_RATE = 0.05\n",
        "NUM_EPOCHS = 50\n",
        "BATCH_SIZE = 128\n",
        "IMG_SIZE = 28\n",
        "\n",
        "# Architecture\n",
        "NUM_CLASSES = 10"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HbwmUxRog0ce",
        "colab_type": "text"
      },
      "source": [
        "## MNIST Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1lPXkVH-g0cf",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 424,
          "referenced_widgets": [
            "c77aa4e9be0c437a9c89ee73f59f5b5d",
            "848739c75e294185b10e6c47177798dc",
            "5d58a9de0d9343749ac7e0fd8d4ad753",
            "d00366ded630474c9666707163bcaf3e",
            "a0ce5db699f2444a89250ea89dfee0f9",
            "258ebc3f21aa4a6fafe484a834c713cc",
            "f3375cd64b0d43da9a5642984bb850b2",
            "b868af1d05784b5ea4057d06260f0a08",
            "bf9c85751f4840c8b9fbc23fc1f2eb16",
            "3b47b0fa093d4df6ab0d1c4b61cb55f0",
            "a49b357bc3c945b798d29f4b4654999d",
            "f6ebf182b573470bb3cec0f8cb45eaea",
            "87702662ec224840a28b8884d7f851ba",
            "0a4bd7085c3343d4adc715b9bb4fe1d6",
            "9fc8039b112b4085a2b44d30d3404b97",
            "00efd74f2d3c406585b5affbd10ef788",
            "6f99f16e7f104aaf922e2b33e0f5f209",
            "0675d1e2cfc04213ab4b2840d0e50f0b",
            "0744ccfd292844dcb779c9080690fd94",
            "679caad6aaaf4bf398ee7102f898c1de",
            "31725af822334fa289b3897b7c01f9be",
            "3b6716ae57b8444e9da5baf0145eb9d7",
            "0446988b9b9a438788315c42594e5c7e",
            "7b2f20a9bfa34dee9931636c996fa9b6",
            "6ce7eaa930bc494ca2bca68ad0d5dd83",
            "1987d5ae6ea641eda46934c307c88b0b",
            "64a9fc9418084713a071b23d2a80e703",
            "0416f10f14a74815a5b6a4f49eca5e9c",
            "42ca724707a848378c7e8dee107eb10e",
            "1a22a1e62b9e4d1996d2c8570707be7f",
            "7585afec669a4d1fae23cb693e1024ad",
            "5aa2ee1a669943128faf8f8338c7d2e4"
          ]
        },
        "outputId": "6508430a-78ea-4c67-94c1-73609bb53798"
      },
      "source": [
        "train_indices = torch.arange(0, 59000)\n",
        "valid_indices = torch.arange(59000, 60000)\n",
        "\n",
        "custom_transform = transforms.Compose([transforms.ToTensor()])\n",
        "\n",
        "\n",
        "train_and_valid = datasets.MNIST(root='data', \n",
        "                                 train=True, \n",
        "                                 transform=custom_transform,\n",
        "                                 download=True)\n",
        "\n",
        "test_dataset = datasets.MNIST(root='data', \n",
        "                              train=False, \n",
        "                              transform=custom_transform,\n",
        "                              download=True)\n",
        "\n",
        "train_dataset = Subset(train_and_valid, train_indices)\n",
        "valid_dataset = Subset(train_and_valid, valid_indices)\n",
        "\n",
        "train_loader = DataLoader(dataset=train_dataset, \n",
        "                          batch_size=BATCH_SIZE,\n",
        "                          num_workers=4,\n",
        "                          shuffle=True)\n",
        "\n",
        "valid_loader = DataLoader(dataset=valid_dataset, \n",
        "                          batch_size=BATCH_SIZE,\n",
        "                          num_workers=4,\n",
        "                          shuffle=False)\n",
        "\n",
        "test_loader = DataLoader(dataset=test_dataset, \n",
        "                         batch_size=BATCH_SIZE,\n",
        "                         num_workers=4,\n",
        "                         shuffle=False)\n",
        "\n",
        "# Checking the dataset\n",
        "for images, labels in train_loader:  \n",
        "    print('Image batch dimensions:', images.shape)\n",
        "    print('Image label dimensions:', labels.shape)\n",
        "    break"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c77aa4e9be0c437a9c89ee73f59f5b5d",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "bf9c85751f4840c8b9fbc23fc1f2eb16",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "6f99f16e7f104aaf922e2b33e0f5f209",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "6ce7eaa930bc494ca2bca68ad0d5dd83",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw\n",
            "Processing...\n",
            "Done!\n",
            "\n",
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "/pytorch/torch/csrc/utils/tensor_numpy.cpp:141: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Image batch dimensions: torch.Size([128, 1, 28, 28])\n",
            "Image label dimensions: torch.Size([128])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vAyzbiS6g0ci",
        "colab_type": "text"
      },
      "source": [
        "## Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2lx-Lzc3g0ck",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def precompute_adjacency_matrix(img_size):\n",
        "    col, row = np.meshgrid(np.arange(img_size), np.arange(img_size))\n",
        "    \n",
        "    # N = img_size^2\n",
        "    # construct 2D coordinate array (shape N x 2) and normalize\n",
        "    # in range [0, 1]\n",
        "    coord = np.stack((col, row), axis=2).reshape(-1, 2) / img_size\n",
        "\n",
        "    # compute pairwise distance matrix (N x N)\n",
        "    dist = cdist(coord, coord, metric='euclidean')\n",
        "    \n",
        "    # Apply Gaussian filter\n",
        "    sigma = 0.05 * np.pi\n",
        "    A = np.exp(- dist / sigma ** 2)\n",
        "    A[A < 0.01] = 0\n",
        "    A = torch.from_numpy(A).float()\n",
        "    \n",
        "    return A\n",
        "\n",
        "    \"\"\"\n",
        "    # Normalization as per (Kipf & Welling, ICLR 2017)\n",
        "    D = A.sum(1)  # nodes degree (N,)\n",
        "    D_hat = (D + 1e-5) ** (-0.5)\n",
        "    A_hat = D_hat.view(-1, 1) * A * D_hat.view(1, -1)  # N,N\n",
        "    \n",
        "    return A_hat\n",
        "    \"\"\"\n",
        "\n",
        "\n",
        "def get_graph_laplacian(A):\n",
        "    # From https://towardsdatascience.com/spectral-graph-convolution-\n",
        "    #   explained-and-implemented-step-by-step-2e495b57f801\n",
        "    #\n",
        "    # Computing the graph Laplacian\n",
        "    # A is an adjacency matrix of some graph G\n",
        "    N = A.shape[0] # number of nodes in a graph\n",
        "    D = np.sum(A, 0) # node degrees\n",
        "    D_hat = np.diag((D + 1e-5)**(-0.5)) # normalized node degrees\n",
        "    L = np.identity(N) - np.dot(D_hat, A).dot(D_hat) # Laplacian\n",
        "    return torch.from_numpy(L).float()"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IMmnZWkGg0co",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 304
        },
        "outputId": "3ebb148f-94f9-4cf8-c5e9-c2692ec9afca"
      },
      "source": [
        "A = precompute_adjacency_matrix(28)\n",
        "plt.imshow(A, vmin=0., vmax=1.)\n",
        "plt.colorbar()\n",
        "plt.show()"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\n",
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAATQAAAD8CAYAAAD5TVjyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2df3CcxZnnP8+MbPlKtiXLkoWxbCwTAee7qwDrYkllQ1gMwSYsULfZLNRmcba4ddUFtpJK6vac2qvcXe7qKrmrChuqWG6dkI1JZUMImyxex+AQQ5bc1UIgIcsFiLGwjS2Dfo9kR5QlzzvP/fH2iLFm3nfekWb0zrzzfFJd877dPf326xHf9NP9dD+iqhiGYSSBVNwdMAzDqBYmaIZhJAYTNMMwEoMJmmEYicEEzTCMxGCCZhhGYqiJoInIDhE5IiIDIrKnFs8wDKNxEZFviMiIiPwqoFxE5AGnIa+IyNVR2q26oIlIGngQ2AlsBe4Ska3Vfo5hGA3NN4EdIeU7gX6XdgMPRWm0FiO0a4ABVT2mqrPAo8DtNXiOYRgNiqo+B0yEVLkdeER9ngc6RGR9uXZbqtXBAjYApwruB4Hfnl9JRHbjKy/Ssvy3Vkk7ej5b1JikUrCsBZ2dhVKbGgQEwXY8GEY4Z8mMqWr3Qr9/8++26fiEF6nuz1+ZeRU4V5C1V1X3VvC4UjqyAXgn7Eu1ELRIuJfbC9C+fJ32f2IP3U8dIzs0PK8ipNvWIL3teCcH0WwJ0WttRUTInTtXVOZXEDDBM5qcH+vjby3m++MTHj87tClS3fT6o+dUddtinrcQamFyngY2Ftz3urxA9HyW7qeOMbpjCy0X9RSVe5kMmpkivakXaSnWYJ2ZQVVJrVgR8AD1Rc0wjAWjQC7i/6pAxToCtRG0F4F+EekTkeXAncD+cl/KDg2bqBlGHaMo59WLlKrAfuBut9p5LTClqqHmJtTA5FTVrIjcBxwC0sA3VPXVKN/1RQ1Gd2yh+ymKzE8vkyENpDf1ljQ/dWYGWltJrVhR2vzMi5qZn4axIKo0+kJEvgNcD3SJyCDwn4FlAKr6v4GDwC3AAPAu8CdR2q3JHJqqHnQdqhgTNcOoTxTFq9J/N6p6V5lyBe6ttN263Clg5qdh1Cc5NFKKi7oQNEkVdyMvapnr+0iv7Swq9zIZdCJD+qIeSKWLynVmBvVyyLLlpR9qomYYFaGAh0ZKcVEXgsayFtJr1hRlZ4eG6Xj6DcZuvbz0SG1yCj1zlpbNG0uP1M7PAthIzTCqhI3QIqCzs8ia9pKi5o1P0H0oxPw8cybc/Dw/a+anYVQBBc6rRkpxUReChoJ3cjBQ1GxOzTDiRyOam2ZyAprNmqgZRj2j4EVMcVE3ggYmaoZRz/g7BaKluKgrQQMTNcOoXwQvYoqLuhM0mCdqHe1F5YWilu5ZV1RuomYY1cdfFJBIKS7qQ9BKvP+cqK1cSXr16qLy7NAwXU++ydR1fcGiNpEh3bMuWNTMT80wIuP7odkIrSyCIK2tRfmazZJ9ewjp6izt0jE8QvtzxxnbeWmwn9r0dKhLB5ifmmFEJacSKcVFXQiaqiJSWtTIeaFzat7wSPic2uSU+akZRhWwEVoF5M6dCxQ1WygwjPhRBI9UpBQXdSNoYKJmGPWOmZwVYqJmGPWJIsxqOlKKi7oTNDBRM4x6xHesTUVKcVGXggYFolbCraIiP7Xu4iA3haIWePSQaulFCjBRM5oWWxRYBPkTZ0NFra0tUNS6fjjA1PXBI7Xc6DgtG9YHjtTwPBM1w3CoCp6mIqW4qB9BCxAHPT+LpFPB5ufwCNK5prRLx+goHT89EWh+5s6eRX8T4qfmjvc289MwfHJIpBQXZQVNRL4hIiMi8quCvE4ReVpEjrrPNS5fROQBERkQkVdE5OrIPQkRB5tTM4z48RcFWiKluIgyQvsmsGNe3h7gsKr2A4
fdPcBOoN+l3cBDFfXGRM0w6pZELAqo6nPAxLzs24F97nofcEdB/iPq8zzQISLrK+qRiZph1C2eSqQUFwuV0p6CoJ9DQF4hNgCnCuoNurwiRGS3iLwkIi+dZ+bCQhM1w6g7mmKngIufV/EZlaq6V1W3qeq2ZZRYSTRRM4y6I6epSCkuFvrk4bwp6T5HXP5pYGNBvV6XtzCiiFoZP7XUqlVF5XOidvOW4BB5mSnSGzeUfL75qRnNiL85PZkjtP3ALne9C3iiIP9ut9p5LTBVYJoujDKiBsF+atkTp0h1BI/Uug4cIXPzZcF+aiNjtFwSECLP/NSMJkMRzms6UoqLKG4b3wH+CbhcRAZF5B7gS8BNInIUuNHdAxwEjgEDwNeAT1WllyHiEOanRs7De2coNERe5zPHg/3UpqfRqbPmp2YY+H/O9e5YW9ZhRFXvCijaXqKuAvcutlMBHfHFoUTMv9y5c76otLb6I6fCrznzM72plzT+yKsQ3/yE0R1b6H7Kvy/Ey2RIA+lNvXgnB+dEbK79mRlobSW1YsXciDFqvw2jsYjXaTYK9bNTIAq2UGAYsaHU/witsQQNTNQMI0aSuigQLyZqhrHkKNEOd7QDHheCiZphLCl+GLuWSCkuGlfQIJpLR8DqpHdyEOlYTaqtrag8L2rj2/tKr47mRW3D+vDz1CxEnpEoLNBw7Qlz6ZiZgXQ6OETeiZOkutcGjtQ6Dx5h/PeuCPZTG5ugZXOInxrm0mEkByW5OwXqizKiFhgiTxVv8O1gP7VMhq4fBZufuelpC5FnNBU2QlsqbE7NMGqKqlR1hCYiO0TkiDs/cU+J8k0i8qyIvOzOV7ylXJvJETQwUTOMGuIvClRn65OIpIEH8c9Q3ArcJSJb51X7T8BjqnoVcCfwV+XaTZaggYmaYdSMqsYUuAYYUNVjqjoLPIp/nmIhCqx21+3A2+UaTZ6ggYmaYdQAf1Egsh9aV/68Q5d2z2suytmJ/wX4hIgM4u8T/7NyfUymoIGJmmHUgAp2Cozlzzt0ae8CHncX8E1V7QVuAb4lIqGalVxBg8X7qa1eGeqnlvlwH+mutUXlXiaDTmRI96wLFjUvZ35qRkNR5Z0CUc5OvAd4DEBV/wlYAXSFNZpsQYPF+amdeptUT3fgSK3j8FHGbg04T21yCj37m1CXDjA/NaOxqGKQlBeBfhHpE5Hl+JP+++fVOYk71UdE/iW+oI2GNZp8QYOF+6nlvFDz0xsbDzc/z5wxPzUjMajC+VwqUirflmaB+4BDwOv4q5mvisgXReQ2V+1zwJ+KyD8D3wE+6Y4oCyS+TVdLjZ2nZhiLwjc5qzcGUtWD+JP9hXlfKLh+DfhgJW02xwgtjy0UGMaisJ0C9YaJmmEsiArdNmKh+QQNTNQMY0FUd+tTLYgSJGWj20/1moi8KiKfdvmdIvK0iBx1n2tcvojIA25/1isicnWtX2JBVEPUOtqLygtFLd2zrqjcRM1oZHIurkC5FBdRpDQLfE5VtwLXAve6PVd7gMOq2g8cdvfg783qd2k38FDVe10tyvmp5UqfaTYnam1tpFevLirPDg3T9eSbnPlQX7CojU2Yn5rRUPirnOlIKS7KCpqqvqOqv3DXZ/GXWDfg77va56rtA+5w17cDj6jP80BHPihxXbLAEHmazeINjyBdnaVdOoZHWP3T44ztvDTYpWN6OtxPLSU2UjPqhsQdwS0im4GrgBeAnoIgwkNA/r/aKHu0EJHd+X1e55mZX7y01GhOzRseCZ9Tm5wy89NoKJJgcgIgIiuBvwM+o6pnCsucs1tFjlKquje/z2sZAdHHlxJbKDCMUBKzyikiy/DF7Nuq+n2XPZw3Jd3niMuPskerPjFRM4xQkrDKKcDDwOuq+pWCov3ALne9C3iiIP9ut9p5LTBVYJrWPyZqhlESVSGrqUgpLqI8+YPAHwM3iMgvXboF+BJwk4gcBW509+BvZTgGDABfAz5V/W7XGBM1wyhJvZucZfdyqu
r/gcBZvu0l6itw7yL7FT9R9n4uWz53asbc1wr3fnoe3pkLphsv2PvZdSCLNzZ+QXnh3s/siVOQ8y5s3+39lBL7Tsv12zAWQ34OrZ5pzp0CUYlynlqYn9qqlYHOt10H3mBye39wiLzhUVo2XhwcIs/zSp8QUqbfhrEY6n2EZoJWjsX6qXWuCTx6aM0/Hg8PkXcm5Dw1d2qHmZ/GUpE4P7SmxebUDANIkB9a02OiZjQ5qpDNpSKluDBBqwQTNaPJMZMzaZioGU2KzaElFRM1o0lRlUgpLkzQFkoUUQtz6VjTHhoib+wjW0qvjuZFrffiks/Pi5q5dBi1wBYFkkwUP7WgEHknTpEKOHooOzTM2n/4NRMfvTzYT210nJbNm8xPzVgyVG0OLfmU81MLC5F3+p3go4cyGdb+uIyf2uQZ81MzlhDBy6UipbgwQasGNqdmNAk2h9YsmKgZCScx56EZETFRM5KM+n8qUVJcmKBVGxM1I8HYKmczYqJmJBC1RYEmJopLR8DqpHdyEGlfFeqnNnFDH+m1nUXlc6K2/iJIFYcTsxB5xmIwk7OZCXPpmJmBdDrYT+2tU6TWdQWO1NYceoOxW0P81CanaNm8MThEHubSYVSOrXI2O2VELdBPTRXv1OlgP7XxCboPBZufubNnw83P87NmfhoV4Y++TNAMm1MzEkLDu22IyAoR+ZmI/LOIvCoi/9Xl94nICyIyICLfFZHlLr/V3Q+48s21fYUGwUTNSABJmEObAW5Q1fcDVwI7XHi6LwP3q+r7gAxwj6t/D5Bx+fe7egaYqBkNjSLkcqlIKS7KPll9fuNul7mkwA3A4y5/H3CHu77d3ePKt7vYngaYqBkNjUZMcRE1cnpaRH6JHx39aeBNYFJVs67KILDBXW8ATgG48ilgbYk2d4vISyLy0nlKhGNLMiZqRiNS5UUBEdkhIkfc9NSegDofF5HX3HTX35ZrM5KgqaqnqlcCvcA1wBWRehze5l5V3aaq25YRcMxNklmsn9rKNlKrVhWV50Vt8kObSXd3F5V7mQw6kSHdsy5Y1MxPzQiiSkM0EUkDDwI7ga3AXSKydV6dfuDzwAdV9V8BnynXbkXGrqpOAs8CHwA6RCT/X0QvcNpdnwY2ug61AO3AOEYxi/FTO/0Oqe61gSO19p8cY+yj7ys9UpucQqenQ106wPzUjGKqOEK7BhhQ1WOqOgs8ij9dVcifAg+qasZ/to6UazTKKme3iHS4638B3AS8ji9sH3PVdgFPuOv97h5X/oyLpm6UYqF+ajkv1Pz0RkfDzc/JKfNTMypCgVxOIiWgKz+l5NLuec3NTU05Cqet8lwGXCYi/1dEnheRHeX6WPyXXMx6YJ8bIqaAx1T1gIi8BjwqIv8deBl42NV/GPiWiAwAE8CdEZ7R3OTFoYTu586d80WltdUftRV+zZmf6U29pPHNyUJ88xNGd2yh+yn/vhAvkyENpDf14p0cnDsUcq79mRlobSW1YsWcGRy130YCUSC6j9mYqm5b5BNbgH7genwr8DkR+TfOUgz8Qiiq+gpwVYn8Y/jDxvn554A/iN5nAzBRMxqCKv7Mc1NTjsJpqzyDwAuqeh44LiJv4Avci0GN2k6BesJWP416p3p+Gy8C/c5Bfzm+Jbd/Xp2/xx+dISJd+CbosbBGTdDqDRM1o26JtiAQZVHAuXTdBxzCn5N/TFVfFZEvishtrtohYNxNbz0L/AdVDV1gjDKHZiw1UczPZcvnViPnvlZofmoOb3LqgvJC87PrScUbvnDRqND8zJ44BTnvwvbN/DSq+NOq6kHg4Ly8LxRcK/BZlyJhI7R6JYqfWljcz7Y20h3tReXZoWG6nnyTqev6SPesKyqfC5F38UXBI7Vs1vzUmhEFzUmkFBcmaPVMuRB56VSw+Tk8gnSuKe3SMTxC+3PHGdt5afDRQ9PvhofIS4mZn02JREzxYIJW79RoTs0bHrE5NaNy6nwzpwlaI2ALBUa9YIJmVAUTNSNu8o
61UVJMmKA1EiZqRswk4YBHo54wUTPiJCfRUkyYoDUiUUQtzKVjTXvo0UOjN28JD5G3qbfk8/OiVnIzfZl+G42BaLQUFyZojcoi/NSyJ06Ral8dOFLrOnCEzEcuCw6RNzxKyyUBIfJmZsDzTNSSSNQFARM0Y0Es0E+NnIc3NBwaIq/z2ePBIfKmp9Gps+F+ath5askj4oKALQoYC8bm1IylxEZoRs0xUTOWilzEFBMmaEnBRM2oNeaHZiwpJmpGjbFVTmNpMVEzaklS5tBcbM6XReSAu+8TkRdcTL3vulMnEZFWdz/gyjfXputGIFFELSxE3pp2Um1tReV5URu7aUvp1dG8qPVeHO6nZkcPGTWikhHap/FPlszzZeB+VX0fkAHucfn3ABmXf7+rZyw15fzUJPjooezxt0JD5K098GvGb70i2E9tdJyWvkuC/dQ0ZyO1BiURJqeI9AIfBb7u7gW4AXjcVdkH3OGub3f3uPLtrr6x1JTzUwsKkaeKN/h2sJ9aJkPX08HmZ256Otz8zGbN/GxElMRsffpL4M95b0F2LTDpzgWHC2PqzcXbc+VTrv4FiMjufMy+88zMLzaqhc2pGdWk0efQRORWYERVf17NB6vqXlXdpqrblhGwTcaoDiZqRpVIgsn5QeA2ETmBH679BuCrQIeI5P9KC2PqzcXbc+XtQGikFmMJMFEzqkGjj9BU9fOq2quqm/Fj5z2jqn+EH1bqY67aLuAJd73f3ePKn3HRW4y4MVEzFkujC1oI/xH4rIgM4M+RPezyHwbWuvzPAnsW10WjqpioGQskqrkZp8lZUVxOVf0J8BN3fQy4pkSdc8AfVKFvRq0oE/dTWluRlpa5UzPmvpaP+9l7ManZWXLT0xeU5+N+Zq7vo+Pp83jjExeUe5kMac2RvqiH7NtDpeN+LluOlIg5Wq7fxhIR4wpmFGynQLMS5tIxMwPpdLCf2lunSPV0B47UOp5+g7FbLy89UpucQs+cpWVzwHlqTshspFaf1PsIzQStmSkjaqF+amEh8sYn6D4UYn6eORNufp6fNfOzXknwHJqRBGxOzYhKA8yhmaAZJmpGdGyEZjQEJmpGBCQXLcWFCZrxHiZqRoNjgmZcSDVEraO9qLxQ1NI964rKTdQaBDM5jYaj3NFDOQ0/T23lStKrVxeVZ4eG6XryTaau6wsWtYkM6Z51waLm5ew8tbiwRQGjYSl39FBLS7Cf2ttDSFdnaZeO4RHanzvO2M5Lg/3UpqdDXTrA/NRiw0ZoRsOyQPOTnBfupzY8Ej6nNjllfmr1igma0dDYQoHhEGyV00gCJmoGVH0OTUR2iMgRF38k8BALEfl9EVER2VauTRM0IxomagZUzeQUkTTwILAT2ArcJSJbS9RbhR/P5IUo3TNBM6JjomZUbw7tGmBAVY+p6iz+4bG3l6j33/ADLZ2L0qgJmlEZUUSthFtFRX5q3d1F5YWiRipd3H4+RF6pRYoy/TaiU4HJ2ZWPGeLS7nlNzcUecRTGJfGfJXI1sFFVfxi1fxWdh2YYwHviEHSeWsCZZnPnqfWsI42/mllIdmiYrh/mmLp+Cx0/TZEdGr6g3MtkSGWztGxYj/fOUPF5bTMzc+4kOlMi8I6dp7Z4ov/Tjalq2TmvIEQkBXwF+GQl37MRmrEwyvmppYPjfnrDI0jnmtIuHaOjdPz0RHCIvLNn0d+E+Kk5kTPzswZoVVc552KPOArjkgCsAv418BMXz+RaYH+5hQETNGPh2Jxa81G9ObQXgX4R6ROR5fjxSvbPPUZ1SlW7VHWzi2fyPHCbqr4U1qgJmrE4TNSaimq5bbiYvfcBh4DXgcdU9VUR+aKI3LbQ/kWaQ3NDvrOAB2RVdZuIdALfBTYDJ4CPq2rGRUn/KnAL8C7wSVX9xUI7aDQAZebUUitWQIl5rbk5tU29/pxaJnNBeT5GweiOLXQ/Rck5tTSQ3tSLd3Kw5Jwara2kVqzw96BW0G8jgCr+U6nqQeDgvLwvBN
S9PkqblYzQfldVryyY6NsDHFbVfuAw70V32gn0u7QbeKiCZxiNio3Ukk9Uc7NBtz7dDuxz1/uAOwryH1Gf5/EDEq9fxHOMRsFELdEIyTltQ4EficjPC/xJelT1HXc9BOT/0sr6lwCIyO68j8p5SiyxG41JFfzUUqtWFZXPidrNW0iv7SwqnxO1jRtKPt/81KpDUgTtd1T1anxz8l4Rua6w0EVGr+g1VHWvqm5T1W3LCPgjMxqTcuepQaCoZU+cItURPFLrOnCEzM2XBY7UciNjtFwSECJvZgY8z0RtMSTB5FTV0+5zBPgB/raF4bwp6T5HXPVy/iVGM7BAPzVyHt47Q6Eh8jqfOR7spzY9jU6dNT+1WtHogiYibW6DKCLSBnwE+BW+z8guV20X8IS73g/cLT7XAlMFpqnRTNicWrKIaG7GaXJGcdvoAX7ge2PQAvytqj4lIi8Cj4nIPcBbwMdd/YP4LhsD+G4bf1L1XhuNg7l0JIs6/+coK2iqegx4f4n8cWB7iXwF7q1K74xkYKKWGOI8vDEKtlPAWBrM/EwE9W5ymqAZS4eJWmOTcMdaw6icKC4dYSHyOlaTamsrKs+L2vj2vtKro3lR27A+/Dw1C5EXjgmaYcwjzKVjZgbS6eAQeSdOkupeGzhS6zx4hPHfuyLYT21sgpbNIX5qmEtHEEnaKWAY1aWMqAWGyFPFG3w72E8tk6HrR8HmZ2562kLkLQLJaaQUFyZoRnzYnFpjYXNohlEGE7WGwkxOwyiHiVrjYCM0w4iAiVpDYCM0w4iKiVr9YyM0w6iAxfqprV4Z6qeW+XAf6a61ReVeJoNOZEj3rAsWNS/X3H5qSjWjPtUEEzSj/liMn9qpt0n1dAeO1DoOH2Xs1oDz1Can0LO/CXXpgOb1UzM/NMNYKAv1U8t5oeanNzYebn6eOWN+amGoRksxYYJm1C82p1Z32AjNMBaDiVr9YI61hlEFTNTqBlsUMIxqYKJWF5igGUa1qIaodbQXlReKWrpnXVG5iZpDsUUBw6gq5fzUcqXPNJsTtbY20qtXF5Vnh4bpevJNznyoL1jUxiaa3k8tEYsCItIhIo+LyK9F5HUR+YCIdIrI0yJy1H2ucXVFRB4QkQEReUVErq7tKxhNxwJD5Gk2izc8gnR1lnbpGB5h9U+PM7bz0mCXjunpcD+1lCR7pJaQRYGvAk+p6hX4AVNeB/YAh1W1Hzjs7sEPRtzv0m7goar22DCgZnNq3vBI+Jza5FTTmp+JcKwVkXbgOuBhAFWdVdVJ4HZgn6u2D7jDXd8OPKI+zwMd+YDEhlFVbKFgadFohzvW+wGPfcAo8Dci8rKIfN0FHO4pCCA8hB+/E2ADcKrg+4Mu7wJEZLeIvCQiL51nZn6xYUTDRG1pSYDJ2QJcDTykqlcB07xnXgLkY3FW9BqquldVt6nqtmWU2MJiGFExUVsyGt7kxB9hDarqC+7+cXyBG86bku5zxJWfBjYWfL/X5RlG7TBRqz0K5DRaiomygqaqQ8ApEbncZW0HXgP2A7tc3i7gCXe9H7jbrXZeC0wVmKaGUTuiiFqYS8ea9kCXjjk/taCjh5yohYbIK7WZvky/644EmJwAfwZ8W0ReAa4E/gfwJeAmETkK3OjuAQ4Cx4AB4GvAp6raY8MII8p5amGitmploPNt14E3mNzeHxwib3iUlo0XB4fI87yGF7VqmpwiskNEjjgXrz0lyj8rIq8596/DInJJuTaL/+VLoKq/BLaVKNpeoq4C90Zp1zBqQl4cSnis6/lZ3/xLtc7F4Zwrc35q6U29pCWFl8lcUO6NjbPmH1sY3bGF7qd8kSskNz2NLF9OelMv3slBNJstal/SaVIrVsyJa9R+1wvVWsEUkTTwIHAT/rTWiyKyX1VfK6j2MrBNVd8VkX8P/E/gD8PatZ0CRjKxObXqE9XcjKZ51wADqnpMVWeBR/Fdvt57nOqzqvquu30efz4+FBM0I7mYqFUV37FWIyWgK++W5d
Luec1Fcu8q4B7gyXJ9jGRyGkbDEmLG5c6d80WlNcD8PDnom59QZH76okag+ellMqQh2PycmYHW1sYzP6OfpDGmqqWmqSpGRD6BP+X14XJ1bYRmJB8bqVWNCkZo5Yjk3iUiNwJ/AdymqmU98E3QjObARG3xVHcO7UWgX0T6RGQ5cCe+y9ccInIV8Nf4YjZSoo0iTNCM5qEKfmphIfLGPrKl9Ib3vKj1Xlzy+Y3jp1a9vZyqmgXuAw7hH3bxmKq+KiJfFJHbXLX/BawEvicivxSR/QHNzWFzaEZzUWZOTZYtRwLm1LInTvl+ZsuXl5xTW/sPs0x89HLW/vh4yTm11OwsLZs34Z06XXJOTVpaSj67XL+XlCo+X1UP4vutFuZ9oeD6xkrbtBGa0XyUO08tLETe6XeCjx7KZFj74+OB5mduehqdPBNsfjqRq1vzU+0IbsOoT2xObWHYEdyGUaeYqFVOQvZyGkYyMVGrCMnlIqW4MEEzDBO1aCi+Y22UFBMmaIYBJmoREKI51UZ0rK0JJmiGkSfK0UMBq5PeyUGkfVWon9rEDX2k13YWlc+J2vqLgs9Tq5cQebYoYBgNRJhLx8wMpNOBI7XsW6dIresKHKmtOfQGY7deHnye2uQULZs3BofIow5cOkzQDKPBKCNqgX5qqninTgf7qY1P0H0o2PzMnT0bbn6en43X/LQ5NMNoUGxOrSS2ymkYjYqJ2vyGG9/kFJHL3cbQfDojIp8RkU4ReVpEjrrPNa6+iMgD7pzwV0Tk6tq/hmHUCBO1gjZpfEFT1SOqeqWqXgn8FvAu8AP82JyHVbUfOMx7sTp3Av0u7QYeqkXHDWPJMFF7j4TNoW0H3lTVt/DP/97n8vcBd7jr24FH1Od5oCMfv9MwGhYTNaCqBzzWhEoF7U7gO+66pyDe5hCQ/zUqPSvcMBqDxfqprWwjtWpVUXle1CY/tJl0d3dRuZfJoBMZ0j3rgkVtqfzUGt3kzONOlbwN+N78Mhe6rqK3EJHd+QAK5yl7sq5h1AeL8VM7/Q6p7rWBI7X2nxxj7KPvKz1Sm5xCp6dDXTqgjJ/aYlEFLxctxUQlI7SdwC9UNX9y3XDelHSf+SNyI50Vrqp7VXWbqm5bRsBJnYZRjyzUTy3nhZqf3j2Ld/QAAAXuSURBVOhouPk5ObU4P7VqkJQRGnAX75mb4J//vctd7wKeKMi/2612XgtMFZimhpEMGnVObbEkQdBEpA0/wvH3C7K/BNwkIkeBG909+EfqHgMGgK8Bn6pabw2jnmg2UVMgp9FSTESKKaCq08DaeXnj+Kue8+sqcG9VemcY9U5e1EqMSuo67ueCUNAYfTIiYDsFDGOxNMtITUnUooBhGEFEEbUyIfLSHe1F5YWilu5ZV1ReKGqBRw9VVdQSMIdmGEYEoviphYlaW1ugqHU9+SZT1/UFilpudJyWiy8KHqnNM0kXjAmaYTQRYS4d52eRdCrY/BweQTrXlHbpGB6h/bnjjO28NPjooel3y4bIWxwJ2JxuGEaF1GhOzRseWdSc2qJRIJeLlmLCBM0wakGdLhQsGhuhGUaTkjhRS9bWJ8MwKiVJoqagmouU4sIEzTBqTZJErc53CpigGcZSUAU/tbCjh0Zv3hIaIi+1eWNR2YKwOTTDMIBF+allT5wi1b46cKTWdeAImY9cFjxSG6zC+RCqtsppGEYBC/RTI+fhDQ2HhsjrfPZ4cIi8au3ntBGaYRgXENOc2uJR1PMipbgwQTOMOGhEUWuA44NM0AwjLhpS1HLRUkyYoBlGnDSQqCmgOY2UoiAiO0TkiIvhu6dEeauIfNeVvyAim8u1aYJmGHHTKKKmWrURmoikgQfxY5VsBe4Ska3zqt0DZFT1fcD9wJfLtWuCZhj1QBRRCwmRx9qOkmee5UXt7X97aXW6Wb1FgWuAAVU9pqqzwKP4MX0LKYz9+ziwXSQ8Hp9ojEusc50QOQscibsfNaQLGIu7EzUiye8GyXq/S1S1OPBnRE
TkKfx/jyisAAp9Rfaq6t6Ctj4G7FDVf+fu/xj4bVW9r6DOr1ydQXf/pqsT+HvUYDv+gjiiqtvi7kStEJGXkvp+SX43SP77VYKq7oi7D+Uwk9MwjDiIEr93ro6ItADtwHhYoyZohmHEwYtAv4j0ichy4E78mL6FFMb+/RjwjJaZI6sXk3Nv+SoNTZLfL8nvBsl/v1hQ1ayI3AccAtLAN1T1VRH5IvCSqu4HHga+JSIDwAS+6IVSF4sChmEY1cBMTsMwEoMJmmEYiSF2QSu3/aHeEZGNIvKsiLwmIq+KyKddfqeIPC0iR93nGpcvIvKAe99XROTqeN+gPCKSFpGXReSAu+9zW1EG3NaU5S6/4q0qcSMiHSLyuIj8WkReF5EPJOm3azZiFbSI2x/qnSzwOVXdClwL3OveYQ9wWFX7gcPuHvx37XdpN/DQ0ne5Yj4NvF5w/2XgfrclJYO/RQUWsFWlDvgq8JSqXgG8H/89k/TbNReqGlsCPgAcKrj/PPD5OPtUhXd6ArgJf+fDepe3Ht95GOCvgbsK6s/Vq8eE7x90GLgBOAAIvud8y/zfEH/F6gPuusXVk7jfIeTd2oHj8/uYlN+uGVPcJucG4FTB/aDLa0iciXUV8ALQo6r5c4+HgPzu4EZ7578E/hzI7zheC0yqaj4Ud2H/597NlU+5+vVKHzAK/I0zqb8uIm0k57drOuIWtMQgIiuBvwM+o6pnCsvU/7/zhvOPEZFbgRFV/XncfakRLcDVwEOqehUwzXvmJdC4v12zEregRdn+UPeIyDJ8Mfu2qn7fZQ+LyHpXvh4YcfmN9M4fBG4TkRP4pyHcgD/n1OG2osCF/a94q0rMDAKDqvqCu38cX+CS8Ns1JXELWpTtD3WNO87kYeB1Vf1KQVHhto1d+HNr+fy73YrZtcBUgXlTV6jq51W1V1U34/82z6jqHwHP4m9FgeJ3q2irSpyo6hBwSkQud1nbgddIwG/XtMQ9iQfcArwBvAn8Rdz9WUD/fwffJHkF+KVLt+DPHR0GjgI/BjpdfcFf2X0T+H/AtrjfIeJ7Xg8ccNdbgJ8BA8D3gFaXv8LdD7jyLXH3O8J7XQm85H6/vwfWJO23a6ZkW58Mw0gMcZuchmEYVcMEzTCMxGCCZhhGYjBBMwwjMZigGYaRGEzQDMNIDCZohmEkhv8PHzN1di2QtuMAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 2 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r4_dxFmRg0cr",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 269
        },
        "outputId": "61e7db69-aec4-4d27-a7ff-03a42010632b"
      },
      "source": [
        "L = get_graph_laplacian(A.numpy())\n",
        "plt.imshow(L, vmin=0., vmax=1.)\n",
        "plt.colorbar()\n",
        "plt.show()"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAATQAAAD8CAYAAAD5TVjyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAbEUlEQVR4nO3df5Bd5X3f8fdHK4FbHCOEYo2MRCTGG7uadAxkh4gh0xJkt4JmEDN1KExSZKJ0Z1rI2LWnqWg6Tkv7h91OTcwMUaMYYuFxjGXFDhqiWMUCj6eZgCUbl1hgwlqgX/yQASETM4B277d/nOcuV8vu3nN3z93z435ezDN77rlnzz1Hu/vhPM95nvMoIjAza4JFZR+AmVlRHGhm1hgONDNrDAeamTWGA83MGsOBZmaN0ZdAk7RR0lOSxiRt7cdnmFl9SbpH0glJP5zhfUm6M2XI45IuzbPfwgNN0hBwF3A1sA64UdK6oj/HzGrti8DGWd6/GhhOZRTYlmen/bhCuwwYi4hDEfEWcB+wqQ+fY2Y1FRHfAV6ZZZNNwL2ReQRYKmllt/0uLuoAO1wAHO14fQz4lakbSRolS16GtOSX/2G8uw+HYmZtr3HypYj4+bl+/z//tXPi5Vcmcm37vcffPAi80bFqe0Rs7+HjpsuRC4DnZ/umfgRaLunktgO8R8viitU3MX78OfBQLLO++FbsOjyf73/5lQm+u/fCXNsOrXz6jYgYmc/nzUU/qpzHgdUdr1eldbMaP/4ciy94H0h9OCQzm68AWjn/K8CccqQfgbYfGJa0VtJZwA3A7q7fFeFQM6uwIDgdE7lKAXYDN6W7neuBUxExa3UT+lDljIhxSbcCe4Eh4J6IOJjzmydDzdVPs+op6OoLSV8BrgSWSzoG/AGwBCAi/jewB7gGGANeB27Os9++tKFFxJ50QHP5ZoeaWQUFwURBf48RcWOX9wO4pdf9VnOkgKufZpXUInKVslQz0ODtUHvfSrS4tJuxZpYEMEHkKmWpbqABRDDx4gmGLlzlUDOrAF+hzVOMjzNx5BiL1qx2qJmVKIDTEblKWSofaJCFWuvZow41sxJFzuqmq5w5ONTMShYwkbOUpTaBBg41szJlIwXylbLUKtDAoWZWHjGRs5SldoEGZ4Yai4bKPhyzgZDdFFCuUpZaBhqku5+HjjB00YW+UjNbAFk/NF+h9U9rIrtSW+tQM1sIrVCuUpZ6Bxqp+vnMERZd9AsONbM+8hXaAonxcVqHDjvUzPooEBMsylXK0ohAA4ea2UKoepWzUX/5naHWOnSYGB8v+5DMGiMQb0W1exU05gqtzVdqZv2RdaxdlKuUpXGBBh2h5rufZoXyTYGSTN79dKiZFSJCTMSiXKUsjQ00cJcOs6K1UK5Slq6BJukeSSck/bBj3TJJD0p6On09L62XpDsljUl6XNKl/Tz4PNymZlaM7KbA4lylLHmu0L4IbJyybiuwLyKGgX3pNcDVwHAqo8C2Yg5zfhxqZvPXiJsCEfEd4JUpqzcBO9LyDuC6jvX3RuYRYKmklUUd7Hw41MzmbyKUq5RlrlG6omPSzxeAFWn5AuBox3bH0rp3kDQq6YCkA6d5c46H0RuHmtncDcRIgTR/Xs/PqIyI7RExEhEjSzh7voeR/3MdamZz1opFuUpZ5vrJL7arkunribT+OLC6Y7tVaV2luJ+aWe+ywenNvELbDWxOy5uB+zvW35Tudq4HTnVUTStlskuHn3xrlksgTsdQrlKWrn/Jkr4CXAksl3QM+APgM8BOSVuAw8D1afM9wDXAGPA6cHMfjrkwk0++9dhPs64iKLXTbB5dAy0ibpzhrQ3TbBvALfM9qIXkAe1meZXbaTaPasftAvGNArPuAjz0qS4cambdNfWmQCM51MxmFuR7uKMf8FghblMzm142jV21I8NXaNOYDDV36TDr4ImGa2uyS4c735oBaX
B6Q0cKDAQ/T83sTL5CqznfKDDLRKjQKzRJGyU9lZ6fuHWa9y+U9LCkx9LzFa/ptk8HWg4ONbP2TYFihj5JGgLuInuG4jrgRknrpmz2n4GdEXEJcAPwR93260DLyaFmVuicApcBYxFxKCLeAu4je55ipwDek5bPBZ7rtlMHWg8cajbIspsCufuhLW8/7zCV0Sm7y/PsxP8C/FYaQ74H+N1ux+i/yh65n5oNsh5GAbwUESPz/LgbgS9GxP+SdDnwJUm/FBGtmb7BV2hz4Oep2SAqeKRAnmcnbgF2AkTE3wDvApbPtlMH2hx53k8bRAVOkrIfGJa0VtJZZI3+u6dsc4T0VB9J/4gs0H4y2079lzgPnf3UXP20pouA061iroEiYlzSrcBeYAi4JyIOSrodOBARu4FPAX8i6d+TNeF9LD2ibEYOtHlym5oNiqzKWVylLiL2kDX2d677dMfyE8AVvezTVc4C+O6nDQqPFBgQDjVruh67bZTCgVYgh5o1W7FDn/qh6ydLWp3GUz0h6aCkj6f1yyQ9KOnp9PW8tF6S7kzjsx6XdGm/T6JKHGrWZK00r0C3UpY8UToOfCoi1gHrgVvSmKutwL6IGAb2pdeQjc0aTmUU2Fb4UVec+6lZE2V3OYdylbJ0DbSIeD4ivp+WXwOeJBuisAnYkTbbAVyXljcB90bmEWBpe1LiQeJHD1nT1OER3D1VdiWtAS4BHgVWdEwi/AKwIi3nGaOFpNH2OK/TvNnjYdeDq5/WNE2ocgIg6d3AnwOfiIifdr6XOrvN2uFtqojYHhEjETGyhLN7+dZacahZUzTmLqekJWRh9uWI+Hpa/WK7Kpm+nkjr84zRGigONWuKJtzlFHA38GREfK7jrd3A5rS8Gbi/Y/1N6W7neuBUR9V0YDnUrO4ixHgsylXKkucv6wrgXwN/K+kHad1/Aj4D7JS0BTgMXJ/e2wNcA4wBrwM3F3rENeZhUlZ3ZVYn8+gaaBHxf2HGVr4N02wfwC3zPK7G6uzS0XrmiEPNaqPdhlZlHilQAj96yOqqETcFrHjup2Z107h+aFYs3yiwuql6PzT/FZXMNwqsLiJgvKAHPPZLtY9uQPhKzerCVU7LxaFmVec2NOuJQ82qLkK5SlkcaBXjRw9ZlVX9poADrYImu3SsWe1Qs8qIcBuazVGMj9N69qirn1YhYqK1KFcpiwOtwtymZlXjNjSbF4eaVUVjnodm5XKoWSVE1o6Wp5TFgVYTDjWrAt/ltMI41KxM4ZsCVrTJUHOXDiuBq5xWuMkuHe58awvMdzmtL/w8NVto2dWXA836xG1qttBq321D0rskfVfS/5N0UNJ/TevXSnpU0pikr0o6K60/O70eS++v6e8pDDaHmi2kJrShvQlcFREfAi4GNqbp6T4L3BER7wdOAlvS9luAk2n9HWk76yOHmi2EQLRai3KVsnT95Mj8fXq5JJUArgJ2pfU7gOvS8qb0mvT+hjS3p/WRQ80WQuQsZck7c/pQmpPzBPAg8GPg1YhoPy/6GHBBWr4AOAqQ3j8FnD/NPkclHZB04DRvzu8sDHCoWZ8VfFNA0kZJT6Xmqa0zbHO9pCdSc9efddtnrt/6iJgALpa0FPgG8MFcRzz7PrcD2wHeo2VlhnqjeN5P66uC/lIlDQF3AR8huyDaL2l3RDzRsc0wcBtwRUSclPTebvvtqbIbEa8CDwOXA0sltQNxFXA8LR8HVqcDWgycC7zcy+fY/LhLh/VLgVdolwFjEXEoIt4C7iNrrur0b4C7IuJk9tlxottO89zl/Pl0ZYakf0CWqE+SBdtH02abgfvT8u70mvT+Q2k2dVtArn5a0QJotZSrAMvbTUqpjE7Z3WTTVNLZbNX2i8AvSvprSY9I2tjtGPP8pq8EdqRLxEXAzoh4QNITwH2S/jvwGHB32v5u4EuSxoBXgBtyfIb1gafIs0IFkL+P2UsRMTLPT1wMDANXktUCvyPpH6ea4ozfMKuIeBy4ZJr1h8guG6
eufwP4jfzHbP3kULMiFVjXmmyaSjqbrdqOAY9GxGngGUl/RxZw+2faqUcKDABXP60wxfXb2A8Mpw76Z5HV5HZP2eYvyK7OkLScrAp6aLadOtAGhEPN5i/fDYE8NwVSl65bgb1kbfI7I+KgpNslXZs22wu8nJq3Hgb+Q0TMeoPRv9kDxNVPm7cCb+9FxB5gz5R1n+5YDuCTqeTiK7QB43k/bc4CoqVcpSwOtAHkfmo2d8pZyuFAG1BuU7M5qfhgTgfaAHOoWc8caFZlDjXLrd2xNk8piQPNHGqWWxMe8GgDwKFmubSUr5TEgWaT3KXDulHkK2VxoNkZJrt0ONRsqrw3BBxoViXup2bTy3lDwDcFrGrcpmbT8hWa1ZVDzd6hlbOUxIFms3Ko2ST3Q7MmcKhZm+9yWiM41AxoThtampvzMUkPpNdrJT2a5tT7anrqJJLOTq/H0vtr+nPottDcT82qrpcrtI+TPVmy7bPAHRHxfuAksCWt3wKcTOvvSNtZQ7if2mBrRJVT0irgXwBfSK8FXAXsSpvsAK5Ly5vSa9L7G9L21hDupzaggsYMffpD4Pd4+4bs+cCr6bngcOacepPz7aX3T6XtzyBptD1n32nenOPhW1ncpjag6t6GJunXgRMR8b0iPzgitkfESESMLOHsIndtC8ShNniaUOW8ArhW0rNk07VfBXweWCqp/VvcOafe5Hx76f1zgVlnarH6cqgNmLpfoUXEbRGxKiLWkM2d91BE/CbZtFIfTZttBu5Py7vTa9L7D6XZW6yhHGoDpO6BNov/CHxS0hhZG9ndaf3dwPlp/SeBrfM7RKsDh1rz5a1ullnl7Ok3LyK+DXw7LR8CLptmmzeA3yjg2KxmOvuptZ454nk/m6jEO5h5eKSAFcr91Jqt6ldoDjQrnPupNViD29DMZuQ2tQaqQRuaA836xqHWQL5Cs0HmUGsWtfKVsjjQrO8carZQHGi2IBxqDeEqp1nGz1OrOd8UMDuTu3TUnK/QzM7k6meNOdDM3smhVj/CdznNZuRQq5mC29AkbZT0VJp/ZMaHWEj6l5JC0ki3fTrQrFQOtZopqMopaQi4C7gaWAfcKGndNNv9HNl8Jo/mOTwHmpXOoVYjxbWhXQaMRcShiHiL7OGxm6bZ7r+RTbT0Rp6dOtCsEhxq9dBDlXN5e86QVEan7Gpy7pGkc16S7LOkS4HVEfGXeY/PvzlWGX6eWg3kv4P5UkR0bfOaiaRFwOeAj/Xyfb5Cs0pxP7UKi0Lvck7OPZJ0zksC8HPALwHfTvOZrAd2d7sx4ECzynH1s8KKa0PbDwxLWivpLLL5SnZPfkzEqYhYHhFr0nwmjwDXRsSB2XbqQLNKcqhVU1HdNtKcvbcCe4EngZ0RcVDS7ZKunevx5fpNSZd8rwETwHhEjEhaBnwVWAM8C1wfESfTLOmfB64BXgc+FhHfn+sB2uDqDLXWocNuU6uCAkcBRMQeYM+UdZ+eYdsr8+yzlyu0X4uIizsa+rYC+yJiGNjH27M7XQ0MpzIKbOvhM8zO4Cu1Cslb3azp0KdNwI60vAO4rmP9vZF5hGxC4pXz+BwbcA61ahDNedpGAP9H0vc6+pOsiIjn0/ILwIq03LV/CYCk0XYfldO8OYdDt0HiRw9VQ9UDLe9vxq9GxHFJ7wUelPSjzjcjIqTeTiMitgPbAd6jZSX+E1hdTHbpWLOa1rNH3aZWhor/pea6QouI4+nrCeAbZMMWXmxXJdPXE2nzbv1LzOYsxsdpPXvU1c+y1L0NTdI5aYAoks4B/hnwQ7I+I5vTZpuB+9PybuAmZdYDpzqqpmbz5ja1khT8tI1+yPPbsAL4RtYbg8XAn0XENyXtB3ZK2gIcBq5P2+8h67IxRtZt4+bCj9oGnrt0lKTiVc6ugRYRh4APTbP+ZWDDNOsDuKWQozObhUNt4ZX58MY8PFLAas3Vz4VV9SqnA81qz6G2QBresdasMiZDbc1qh1o/OdDMFsZklw53vu2LJo
0UMKsFP0+tv9SKXKUsDjRrHLep9Ynb0MzK4VDrD1c5zUriUOsDX6GZlcehVixfoZmVzKFWIF+hmZXPz1MrQBQ661NfONBsYEx26XCozYn7oZlVjPupzVNEvlISB5oNHLepzZ2v0MwqyKE2B+5Ya1ZdDrXe+aaAWYU51HrjQDOrOIdaToFvCpjVgfup5dOImwKSlkraJelHkp6UdLmkZZIelPR0+npe2laS7pQ0JulxSZf29xTMiuEuHTk05KbA54FvRsQHySZMeRLYCuyLiGFgX3oNcDUwnMoosK3QIzbrI1c/Z9aIjrWSzgX+CXA3QES8FRGvApuAHWmzHcB1aXkTcG9kHgGWtickNqsDh9oMIt/DHav+gMe1wE+AP5X0mKQvpAmHV3RMIPwC2fydABcARzu+/1hadwZJo5IOSDpwmjfnfgZmfeBQm0EDqpyLgUuBbRFxCfAz3q5eArTn4uzpNCJie0SMRMTIEs7u5VvNFoRD7Z1qX+Uku8I6FhGPpte7yALuxXZVMn09kd4/Dqzu+P5VaZ1Z7TjUOgTQinylJF0DLSJeAI5K+kBatQF4AtgNbE7rNgP3p+XdwE3pbud64FRH1dSsdtylo0PFq5x5fzq/C3xZ0lnAIeBmsjDcKWkLcBi4Pm27B7gGGANeT9ua1Vrno4dazxwhxsfLPqRSFFmdlLSRrAfFEPCFiPjMlPc/CfwOME7Wjv/bEXF4tn3mCrSI+AEwMs1bG6bZNoBb8uzXrE46+6m1Dh0eyFAr6g6mpCHgLuAjZM1a+yXtjognOjZ7DBiJiNcl/VvgfwD/arb9eqSAWQ8Guk2t2KdtXAaMRcShiHgLuI+sy9fbHxfxcES8nl4+QtYePysHmlmPBjXUso61kasAy9vdslIZnbK7XN27OmwB/qrbMQ7OT8OsQJ2hNlDVz/xP0ngpIqZrpuqZpN8ia/L6p9229RWa2RwN4pVaD1do3eTq3iXpw8DvA9dGRNce+A40s3kYqFArtg1tPzAsaW3qPXEDWZevSZIuAf6YLMxOTLOPd3Cgmc3T4PRTK24sZ0SMA7cCe8kedrEzIg5Kul3StWmz/wm8G/iapB9I2j3D7iY1+V/fbMFMdulYs5rWs0eb26ZW4MMbI2IPWb/VznWf7lj+cK/79BWaWUFifJzWs0ebW/0MP4LbbKA0vk3Nj+A2GyyNDrWKj+V0oJn1QVNDTa1WrlIWB5pZnzQu1IKsY22eUhIHmlkfNSnURL5OtTk71vaFA82szyZDbc3q2oeabwqY2dtdOure+daBZmZA/ef9dBuamXWqe5ua73Ka2RnqG2o5q5tVrnJK+kAaGNouP5X0CUnLJD0o6en09by0vSTdKWlM0uOSLu3/aZjVSy1DLah/oEXEUxFxcURcDPwy2cQn3yCbm3NfRAwD+3h7rs6rgeFURoFt/Thws7qrZag1rA1tA/DjNPPKJmBHWr8DuC4tbwLujcwjwNL2/J1mdqa6hVrT+qHdAHwlLa/omG/zBWBFWu71WeFmA61Wz1Ore5WzLT1V8lrga1PfS1PX9XQWkkbbEyicpuuTdc0arXPez8qGWgRMtPKVkvRyhXY18P2IeDG9frFdlUxf24/IzfWs8IjYHhEjETGyhLN7P3KzhqlFP7WmXKEBN/J2dROy539vTsubgfs71t+U7nauB051VE3NbBaVb1NrQqBJOodshuOvd6z+DPARSU8DH06vIXuk7iFgDPgT4N8VdrRmA6CyoRZAK/KVkuT614qInwHnT1n3Mtldz6nbBnBLIUdnNqCqOe9nQJTYJyMHjxQwq6jKXakFjbopYGYLrHJdOprQhmZm5emcIo9FQyUfjAPNzOYpxseZOHSEoYvKvFJrwOB0M6uI1kT2kMiynnwbQKuVr5TEgWZWI5NPvi0t1HyFZmYFKi/UmjX0ycwqopRQC4ho5SplcaCZ1VQpoVbxkQIONLMaa4fa0IWrFibU3IZmZv0U4+NMHDnG0Ir3gtTHDwrf5T
Sz/ovxccafe57FF7yv/6FW4Su0CoylMLNCRDB+/DkWX/A+xo8/14dgCWJiouB9FstXaGZN0hFqhV+p1eDxQQ40s6bpa6i18pWSONDMmqgPoRZAtCJXyUPSRklPpTl8t07z/tmSvpref1TSmm77dKCZNVXRoRZR2BWapCHgLrK5StYBN0paN2WzLcDJiHg/cAfw2W77daCZNVk71FYVM5NkTEzkKjlcBoxFxKGIeAu4j2xO306dc//uAjZIsydzJe5yvsbJv/9W7Hqq7OPoo+XAS2UfRJ80+dygCecXwBEAfmE+u3mNk3u/FbuW59z8XZIOdLzeHhHbO15PN3/vr0zZx+Q2ETEu6RTZVAAz/jwqEWjAUxExUvZB9IukA009vyafGzT//HoRERvLPoZuXOU0szLkmb93chtJi4FzgZdn26kDzczKsB8YlrRW0lnADWRz+nbqnPv3o8BDaVa5GVWlyrm9+ya11uTza/K5QfPPrxSpTexWYC8wBNwTEQcl3Q4ciIjdwN3AlySNAa+Qhd6s1CXwzMxqw1VOM2sMB5qZNUbpgdZt+EPVSVot6WFJT0g6KOnjaf0ySQ9Kejp9PS+tl6Q70/k+LunScs+gO0lDkh6T9EB6vTYNRRlLQ1POSut7HqpSNklLJe2S9CNJT0q6vEk/u0FTaqDlHP5QdePApyJiHbAeuCWdw1ZgX0QMA/vSa8jOdTiVUWDbwh9yzz4OPNnx+rPAHWlIykmyISowh6EqFfB54JsR8UHgQ2Tn2aSf3WCJiNIKcDmwt+P1bcBtZR5TAed0P/AR4ClgZVq3kqzzMMAfAzd2bD+5XRULWf+gfcBVwAOAyHpqL576MyS7Y3V5Wl6ctlPZ5zDLuZ0LPDP1GJvysxvEUnaVc7rhD8UMOitBqmJdAjwKrIiI59NbLwAr0nLdzvkPgd8D2iOOzwdejYjx9Lrz+M8YqgK0h6pU1VrgJ8Cfpir1FySdQ3N+dgOn7EBrDEnvBv4c+ERE/LTzvcj+d167/jGSfh04ERHfK/tY+mQxcCmwLSIuAX7G29VLoL4/u0FVdqDlGf5QeZKWkIXZlyPi62n1i5JWpvdXAifS+jqd8xXAtZKeJXsawlVkbU5L01AUOPP4ex6qUrJjwLGIeDS93kUWcE342Q2ksgMtz/CHSkuPM7kbeDIiPtfxVuewjc1kbWvt9TelO2brgVMd1ZtKiYjbImJVRKwh+9k8FBG/CTxMNhQF3nluPQ1VKVNEvAAclfSBtGoD8AQN+NkNrLIb8YBrgL8Dfgz8ftnHM4fj/1WyKsnjwA9SuYas7Wgf8DTwLWBZ2l5kd3Z/DPwtMFL2OeQ8zyuBB9LyRcB3gTHga8DZaf270uux9P5FZR93jvO6GDiQfn5/AZzXtJ/dIBUPfTKzxii7ymlmVhgHmpk1hgPNzBrDgWZmjeFAM7PGcKCZWWM40MysMf4/xTyv6ZISLfYAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 2 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zdOyfo0rg0cu",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "##########################\n",
        "### MODEL\n",
        "##########################\n",
        "\n",
        "from scipy.sparse.linalg import eigsh\n",
        "\n",
        "\n",
        "class GraphNet(nn.Module):\n",
        "    \"\"\"Single spectral graph-convolution layer followed by a linear classifier.\n",
        "\n",
        "    Each input image is treated as a signal on a graph with img_size**2 nodes\n",
        "    (one per pixel) whose adjacency matrix comes from\n",
        "    `precompute_adjacency_matrix(img_size)`. The signal and the learnable\n",
        "    filters are projected onto the `num_eigenvectors` Laplacian eigenvectors\n",
        "    with the smallest-magnitude eigenvalues, multiplied in that spectral\n",
        "    domain, projected back onto the nodes and fed to a bias-free linear layer.\n",
        "\n",
        "    Args:\n",
        "        img_size: side length of the square input image.\n",
        "        num_filters: number of spectral filters F.\n",
        "        num_classes: number of output classes.\n",
        "        num_eigenvectors: number of Laplacian eigenvectors K to keep (default 20).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, img_size=28, num_filters=2, num_classes=10, num_eigenvectors=20):\n",
        "        super(GraphNet, self).__init__()\n",
        "\n",
        "        n_rows = img_size**2\n",
        "        self.fc = nn.Linear(n_rows*num_filters, num_classes, bias=False)\n",
        "\n",
        "        A = precompute_adjacency_matrix(img_size)\n",
        "        L = get_graph_laplacian(A.numpy())\n",
        "        # Eigen-decomposition of L; only the eigenvectors are needed\n",
        "        # (which='SM' selects the smallest-magnitude eigenvalues).\n",
        "        _, V = eigsh(L.numpy(), k=num_eigenvectors, which='SM')\n",
        "\n",
        "        # float32, matching W_spectral, for the matmuls in forward().\n",
        "        V = torch.from_numpy(V).float()  # [H*W, K]\n",
        "\n",
        "        # Learnable filters, [H*W, F]. This must be an nn.Parameter attribute,\n",
        "        # not a buffer: buffers are excluded from model.parameters(), so the\n",
        "        # optimizer would never update them.\n",
        "        self.W_spectral = nn.Parameter(torch.ones((n_rows, num_filters)))\n",
        "        torch.nn.init.kaiming_uniform_(self.W_spectral)\n",
        "\n",
        "        self.register_buffer('A', A)\n",
        "        self.register_buffer('L', L)\n",
        "        self.register_buffer('V', V)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Apply the spectral graph convolution, then the classifier.\n",
        "\n",
        "        Args:\n",
        "            x: input batch whose per-sample size is img_size**2\n",
        "               (e.g. [B, 1, H, W]); it is flattened to [B, H*W, 1].\n",
        "\n",
        "        Returns:\n",
        "            (logits, probas), both of shape [B, num_classes]; probas is\n",
        "            the softmax of logits over the class dimension.\n",
        "        \"\"\"\n",
        "        B = x.size(0)  # batch size\n",
        "\n",
        "        # Use .t() rather than .T: tracing .T records a `numpy_T` op that cannot\n",
        "        # be exported to ONNX, which breaks ONNX-based graph visualisers.\n",
        "        V_T = self.V.t()  # [K, H*W]\n",
        "\n",
        "        ### Reshape inputs\n",
        "        # flatten each sample: [B, ...] => [B, H*W, 1]\n",
        "        x_reshape = x.view(B, -1, 1)\n",
        "\n",
        "        ### Spectral convolution on graphs\n",
        "        # torch.matmul broadcasts the 2-D factors over the batch dimension,\n",
        "        # so V and W_spectral never have to be expanded to batch size B.\n",
        "        X_hat = torch.matmul(V_T, x_reshape)        # [B, K, 1] node features in the spectral domain\n",
        "        W_hat = torch.matmul(V_T, self.W_spectral)  # [K, F] filters in the spectral domain (shared by all samples)\n",
        "        Y = torch.matmul(self.V, X_hat * W_hat)     # [B, H*W, F] result of the convolution\n",
        "\n",
        "        ### Fully connected\n",
        "        logits = self.fc(Y.reshape(B, -1))\n",
        "        probas = F.softmax(logits, dim=1)\n",
        "        return logits, probas"
      ],
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sMraApdZg0cx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Seed before building the model so its random weight initialisation is reproducible.\n",
        "torch.manual_seed(RANDOM_SEED)\n",
        "\n",
        "model = GraphNet(img_size=IMG_SIZE, num_classes=NUM_CLASSES).to(DEVICE)\n",
        "\n",
        "# Plain SGD over model.parameters().\n",
        "optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)"
      ],
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ko0TSuS_hzsA",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 365
        },
        "outputId": "32fca4c3-8a31-4e0c-bbb0-d5dbc0c3f5c2"
      },
      "source": [
        "# hiddenlayer traces the model and converts the trace with torch.onnx, so every\n",
        "# op used in forward() must be ONNX-exportable. A single dummy image of the\n",
        "# model's configured size (IMG_SIZE, not a hard-coded 28) is enough to trace it.\n",
        "import hiddenlayer as hl\n",
        "\n",
        "hl.build_graph(model, torch.zeros([1, 1, IMG_SIZE, IMG_SIZE], device=DEVICE))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Jhh_bMP_jCYV",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 330
        },
        "outputId": "981b2067-7fcc-4288-fbe3-f2732c2f87d6"
      },
      "source": [
        "import tensorwatch as tw\n",
        "tw.draw_model(model, [128, 1, 28, 28])"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "error",
          "ename": "RuntimeError",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-16-aebaab0584ea>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtensorwatch\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtw\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtw\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m28\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m28\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorwatch/__init__.py\u001b[0m in \u001b[0;36mdraw_model\u001b[0;34m(model, input_shape, orientation, png_filename)\u001b[0m\n\u001b[1;32m     33\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdraw_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morientation\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'TB'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpng_filename\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m#orientation = 'LR' for landscpe\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m     \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mmodel_graph\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhiddenlayer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpytorch_draw_model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m     \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpytorch_draw_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     36\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     37\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorwatch/model_graph/hiddenlayer/pytorch_draw_model.py\u001b[0m in \u001b[0;36mdraw_graph\u001b[0;34m(model, args)\u001b[0m\n\u001b[1;32m     33\u001b[0m         \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m     \u001b[0mdot\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdraw_img_classifier\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     36\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mDotWrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     37\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorwatch/model_graph/hiddenlayer/pytorch_draw_model.py\u001b[0m in \u001b[0;36mdraw_img_classifier\u001b[0;34m(model, dataset, display_param_nodes, rankdir, styles, input_shape)\u001b[0m\n\u001b[1;32m     61\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     62\u001b[0m         \u001b[0mnon_para_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdistiller\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake_non_parallel_copy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m         \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSummaryGraph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnon_para_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdummy_input\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     65\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0msgraph2dot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdisplay_param_nodes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrankdir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstyles\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorwatch/model_graph/hiddenlayer/summary_graph.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, model, dummy_input, apply_scope_name_workarounds)\u001b[0m\n\u001b[1;32m    134\u001b[0m             \u001b[0;31m# Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    135\u001b[0m             \u001b[0;31m# composing a GEMM operation; etc.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 136\u001b[0;31m             \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimize_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOperatorExportTypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mONNX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    137\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    138\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mops\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mOrderedDict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py\u001b[0m in \u001b[0;36m_optimize_trace\u001b[0;34m(graph, operator_export_type)\u001b[0m\n\u001b[1;32m    181\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_optimize_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator_export_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    182\u001b[0m     \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 183\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimize_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator_export_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    184\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    185\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py\u001b[0m in \u001b[0;36m_optimize_graph\u001b[0;34m(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict)\u001b[0m\n\u001b[1;32m    152\u001b[0m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jit_pass_erase_number_types\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    153\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 154\u001b[0;31m         \u001b[0mgraph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jit_pass_onnx\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator_export_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    155\u001b[0m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jit_pass_lint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    156\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py\u001b[0m in \u001b[0;36m_run_symbolic_function\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    197\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_run_symbolic_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    198\u001b[0m     \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_symbolic_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    200\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    201\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py\u001b[0m in \u001b[0;36m_run_symbolic_function\u001b[0;34m(g, n, inputs, env, operator_export_type)\u001b[0m\n\u001b[1;32m    737\u001b[0m                                   \u001b[0;34m\"torch.onnx.symbolic_opset{}.{} does not exist\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    738\u001b[0m                                   .format(op_name, opset_version, op_name))\n\u001b[0;32m--> 739\u001b[0;31m                 \u001b[0mop_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msym_registry\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_registered_op\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m''\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopset_version\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    740\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mop_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mattrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    741\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic_registry.py\u001b[0m in \u001b[0;36mget_registered_op\u001b[0;34m(opname, domain, version)\u001b[0m\n\u001b[1;32m    107\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    108\u001b[0m             \u001b[0mmsg\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m\"Please open a bug to request ONNX export support for the missing operator.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 109\u001b[0;31m         \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    110\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0m_registry\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdomain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mversion\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mopname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mRuntimeError\u001b[0m: Exporting the operator numpy_T to ONNX opset version 9 is not supported. Please open a bug to request ONNX export support for the missing operator."
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NGrQF-YPg0c0",
        "colab_type": "text"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Bdi4wqbCg0c0",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "588159b9-c8b0-469e-c4b5-680e556377e5"
      },
      "source": [
        "def compute_acc(model, data_loader, device):\n",
        "    \"\"\"Return the classification accuracy (in percent) of `model` on `data_loader`.\n",
        "\n",
        "    `model(features)` must return a `(logits, probas)` pair; the predicted\n",
        "    label is the argmax of `probas` along dim 1. Putting the model in eval\n",
        "    mode is left to the caller.\n",
        "\n",
        "    Returns a 0-d float tensor holding the percentage of correct predictions.\n",
        "    Raises ValueError if `data_loader` yields no examples.\n",
        "    \"\"\"\n",
        "    correct_pred, num_examples = 0, 0\n",
        "    # Gradients are never needed for evaluation; disabling them here saves\n",
        "    # memory even if the caller did not wrap the call in no-grad mode.\n",
        "    with torch.no_grad():\n",
        "        for features, targets in data_loader:\n",
        "            features = features.to(device)\n",
        "            targets = targets.to(device)\n",
        "            logits, probas = model(features)\n",
        "            _, predicted_labels = torch.max(probas, 1)\n",
        "            num_examples += targets.size(0)\n",
        "            correct_pred += (predicted_labels == targets).sum()\n",
        "    if num_examples == 0:\n",
        "        # Without this guard an empty loader leaves `correct_pred` as the int 0\n",
        "        # (opaque AttributeError below), and empty batches would yield NaN.\n",
        "        raise ValueError('data_loader yielded no examples')\n",
        "    return correct_pred.float() / num_examples * 100\n",
        "\n",
        "\n",
        "start_time = time.time()\n",
        "\n",
        "cost_list = []\n",
        "# Accuracies are stored as plain Python floats rather than (possibly CUDA)\n",
        "# tensors so the plotting cells can pass them straight to matplotlib/NumPy.\n",
        "train_acc_list, valid_acc_list = [], []\n",
        "\n",
        "\n",
        "for epoch in range(NUM_EPOCHS):\n",
        "\n",
        "    model.train()\n",
        "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
        "\n",
        "        features = features.to(DEVICE)\n",
        "        targets = targets.to(DEVICE)\n",
        "\n",
        "        ### FORWARD AND BACK PROP\n",
        "        logits, probas = model(features)\n",
        "        cost = F.cross_entropy(logits, targets)\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        cost.backward()\n",
        "\n",
        "        ### UPDATE MODEL PARAMETERS\n",
        "        optimizer.step()\n",
        "\n",
        "        #################################################\n",
        "        ### CODE ONLY FOR LOGGING BEYOND THIS POINT\n",
        "        #################################################\n",
        "        cost_list.append(cost.item())\n",
        "        if not batch_idx % 150:\n",
        "            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n",
        "                  f'Batch {batch_idx:03d}/{len(train_loader):03d} |'\n",
        "                  f' Cost: {cost:.4f}')\n",
        "\n",
        "    model.eval()\n",
        "    with torch.set_grad_enabled(False): # save memory during inference\n",
        "\n",
        "        train_acc = compute_acc(model, train_loader, device=DEVICE)\n",
        "        valid_acc = compute_acc(model, valid_loader, device=DEVICE)\n",
        "\n",
        "        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\\n'\n",
        "              f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')\n",
        "\n",
        "        # float() turns the scalar accuracy tensor into a host-side Python float.\n",
        "        train_acc_list.append(float(train_acc))\n",
        "        valid_acc_list.append(float(valid_acc))\n",
        "\n",
        "    elapsed = (time.time() - start_time)/60\n",
        "    print(f'Time elapsed: {elapsed:.2f} min')\n",
        "\n",
        "elapsed = (time.time() - start_time)/60\n",
        "print(f'Total Training Time: {elapsed:.2f} min')"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 001/050 | Batch 000/461 | Cost: 2.3048\n",
            "Epoch: 001/050 | Batch 150/461 | Cost: 1.2503\n",
            "Epoch: 001/050 | Batch 300/461 | Cost: 0.9677\n",
            "Epoch: 001/050 | Batch 450/461 | Cost: 0.9317\n",
            "Epoch: 001/050\n",
            "Train ACC: 75.17 | Validation ACC: 79.60\n",
            "Time elapsed: 0.22 min\n",
            "Epoch: 002/050 | Batch 000/461 | Cost: 0.9826\n",
            "Epoch: 002/050 | Batch 150/461 | Cost: 0.7964\n",
            "Epoch: 002/050 | Batch 300/461 | Cost: 0.8447\n",
            "Epoch: 002/050 | Batch 450/461 | Cost: 0.9356\n",
            "Epoch: 002/050\n",
            "Train ACC: 77.79 | Validation ACC: 82.50\n",
            "Time elapsed: 0.43 min\n",
            "Epoch: 003/050 | Batch 000/461 | Cost: 0.8216\n",
            "Epoch: 003/050 | Batch 150/461 | Cost: 0.7701\n",
            "Epoch: 003/050 | Batch 300/461 | Cost: 0.7437\n",
            "Epoch: 003/050 | Batch 450/461 | Cost: 0.7210\n",
            "Epoch: 003/050\n",
            "Train ACC: 78.63 | Validation ACC: 83.00\n",
            "Time elapsed: 0.64 min\n",
            "Epoch: 004/050 | Batch 000/461 | Cost: 0.7137\n",
            "Epoch: 004/050 | Batch 150/461 | Cost: 0.7224\n",
            "Epoch: 004/050 | Batch 300/461 | Cost: 0.7001\n",
            "Epoch: 004/050 | Batch 450/461 | Cost: 0.7026\n",
            "Epoch: 004/050\n",
            "Train ACC: 79.43 | Validation ACC: 83.60\n",
            "Time elapsed: 0.86 min\n",
            "Epoch: 005/050 | Batch 000/461 | Cost: 0.7232\n",
            "Epoch: 005/050 | Batch 150/461 | Cost: 0.7898\n",
            "Epoch: 005/050 | Batch 300/461 | Cost: 0.6990\n",
            "Epoch: 005/050 | Batch 450/461 | Cost: 0.6734\n",
            "Epoch: 005/050\n",
            "Train ACC: 79.75 | Validation ACC: 84.00\n",
            "Time elapsed: 1.07 min\n",
            "Epoch: 006/050 | Batch 000/461 | Cost: 0.6119\n",
            "Epoch: 006/050 | Batch 150/461 | Cost: 0.6370\n",
            "Epoch: 006/050 | Batch 300/461 | Cost: 0.5354\n",
            "Epoch: 006/050 | Batch 450/461 | Cost: 0.5525\n",
            "Epoch: 006/050\n",
            "Train ACC: 80.24 | Validation ACC: 84.40\n",
            "Time elapsed: 1.29 min\n",
            "Epoch: 007/050 | Batch 000/461 | Cost: 0.5528\n",
            "Epoch: 007/050 | Batch 150/461 | Cost: 0.7764\n",
            "Epoch: 007/050 | Batch 300/461 | Cost: 0.6738\n",
            "Epoch: 007/050 | Batch 450/461 | Cost: 0.6649\n",
            "Epoch: 007/050\n",
            "Train ACC: 80.38 | Validation ACC: 85.10\n",
            "Time elapsed: 1.50 min\n",
            "Epoch: 008/050 | Batch 000/461 | Cost: 0.5831\n",
            "Epoch: 008/050 | Batch 150/461 | Cost: 0.5694\n",
            "Epoch: 008/050 | Batch 300/461 | Cost: 0.6280\n",
            "Epoch: 008/050 | Batch 450/461 | Cost: 0.6157\n",
            "Epoch: 008/050\n",
            "Train ACC: 80.82 | Validation ACC: 84.80\n",
            "Time elapsed: 1.71 min\n",
            "Epoch: 009/050 | Batch 000/461 | Cost: 0.6353\n",
            "Epoch: 009/050 | Batch 150/461 | Cost: 0.4995\n",
            "Epoch: 009/050 | Batch 300/461 | Cost: 0.6532\n",
            "Epoch: 009/050 | Batch 450/461 | Cost: 0.6075\n",
            "Epoch: 009/050\n",
            "Train ACC: 81.06 | Validation ACC: 85.30\n",
            "Time elapsed: 1.93 min\n",
            "Epoch: 010/050 | Batch 000/461 | Cost: 0.5638\n",
            "Epoch: 010/050 | Batch 150/461 | Cost: 0.4921\n",
            "Epoch: 010/050 | Batch 300/461 | Cost: 0.5796\n",
            "Epoch: 010/050 | Batch 450/461 | Cost: 0.8327\n",
            "Epoch: 010/050\n",
            "Train ACC: 81.18 | Validation ACC: 85.00\n",
            "Time elapsed: 2.14 min\n",
            "Epoch: 011/050 | Batch 000/461 | Cost: 0.6995\n",
            "Epoch: 011/050 | Batch 150/461 | Cost: 0.7460\n",
            "Epoch: 011/050 | Batch 300/461 | Cost: 0.6491\n",
            "Epoch: 011/050 | Batch 450/461 | Cost: 0.6694\n",
            "Epoch: 011/050\n",
            "Train ACC: 81.51 | Validation ACC: 86.40\n",
            "Time elapsed: 2.36 min\n",
            "Epoch: 012/050 | Batch 000/461 | Cost: 0.6969\n",
            "Epoch: 012/050 | Batch 150/461 | Cost: 0.6250\n",
            "Epoch: 012/050 | Batch 300/461 | Cost: 0.6810\n",
            "Epoch: 012/050 | Batch 450/461 | Cost: 0.7608\n",
            "Epoch: 012/050\n",
            "Train ACC: 81.70 | Validation ACC: 86.60\n",
            "Time elapsed: 2.57 min\n",
            "Epoch: 013/050 | Batch 000/461 | Cost: 0.6763\n",
            "Epoch: 013/050 | Batch 150/461 | Cost: 0.6019\n",
            "Epoch: 013/050 | Batch 300/461 | Cost: 0.6995\n",
            "Epoch: 013/050 | Batch 450/461 | Cost: 0.7679\n",
            "Epoch: 013/050\n",
            "Train ACC: 82.01 | Validation ACC: 86.90\n",
            "Time elapsed: 2.79 min\n",
            "Epoch: 014/050 | Batch 000/461 | Cost: 0.7620\n",
            "Epoch: 014/050 | Batch 150/461 | Cost: 0.6022\n",
            "Epoch: 014/050 | Batch 300/461 | Cost: 0.7215\n",
            "Epoch: 014/050 | Batch 450/461 | Cost: 0.6713\n",
            "Epoch: 014/050\n",
            "Train ACC: 81.84 | Validation ACC: 86.30\n",
            "Time elapsed: 3.00 min\n",
            "Epoch: 015/050 | Batch 000/461 | Cost: 0.6591\n",
            "Epoch: 015/050 | Batch 150/461 | Cost: 0.6754\n",
            "Epoch: 015/050 | Batch 300/461 | Cost: 0.6556\n",
            "Epoch: 015/050 | Batch 450/461 | Cost: 0.6266\n",
            "Epoch: 015/050\n",
            "Train ACC: 82.06 | Validation ACC: 86.80\n",
            "Time elapsed: 3.21 min\n",
            "Epoch: 016/050 | Batch 000/461 | Cost: 0.5202\n",
            "Epoch: 016/050 | Batch 150/461 | Cost: 0.4029\n",
            "Epoch: 016/050 | Batch 300/461 | Cost: 0.5752\n",
            "Epoch: 016/050 | Batch 450/461 | Cost: 0.5685\n",
            "Epoch: 016/050\n",
            "Train ACC: 82.14 | Validation ACC: 86.60\n",
            "Time elapsed: 3.43 min\n",
            "Epoch: 017/050 | Batch 000/461 | Cost: 0.5394\n",
            "Epoch: 017/050 | Batch 150/461 | Cost: 0.4913\n",
            "Epoch: 017/050 | Batch 300/461 | Cost: 0.5236\n",
            "Epoch: 017/050 | Batch 450/461 | Cost: 0.6327\n",
            "Epoch: 017/050\n",
            "Train ACC: 82.19 | Validation ACC: 87.40\n",
            "Time elapsed: 3.64 min\n",
            "Epoch: 018/050 | Batch 000/461 | Cost: 0.6354\n",
            "Epoch: 018/050 | Batch 150/461 | Cost: 0.5510\n",
            "Epoch: 018/050 | Batch 300/461 | Cost: 0.6542\n",
            "Epoch: 018/050 | Batch 450/461 | Cost: 0.5690\n",
            "Epoch: 018/050\n",
            "Train ACC: 82.22 | Validation ACC: 87.10\n",
            "Time elapsed: 3.85 min\n",
            "Epoch: 019/050 | Batch 000/461 | Cost: 0.5523\n",
            "Epoch: 019/050 | Batch 150/461 | Cost: 0.4327\n",
            "Epoch: 019/050 | Batch 300/461 | Cost: 0.6313\n",
            "Epoch: 019/050 | Batch 450/461 | Cost: 0.9467\n",
            "Epoch: 019/050\n",
            "Train ACC: 82.56 | Validation ACC: 87.30\n",
            "Time elapsed: 4.07 min\n",
            "Epoch: 020/050 | Batch 000/461 | Cost: 0.6790\n",
            "Epoch: 020/050 | Batch 150/461 | Cost: 0.5775\n",
            "Epoch: 020/050 | Batch 300/461 | Cost: 0.5814\n",
            "Epoch: 020/050 | Batch 450/461 | Cost: 0.6594\n",
            "Epoch: 020/050\n",
            "Train ACC: 82.60 | Validation ACC: 87.70\n",
            "Time elapsed: 4.28 min\n",
            "Epoch: 021/050 | Batch 000/461 | Cost: 0.7137\n",
            "Epoch: 021/050 | Batch 150/461 | Cost: 0.7988\n",
            "Epoch: 021/050 | Batch 300/461 | Cost: 0.5479\n",
            "Epoch: 021/050 | Batch 450/461 | Cost: 0.6271\n",
            "Epoch: 021/050\n",
            "Train ACC: 82.71 | Validation ACC: 88.10\n",
            "Time elapsed: 4.50 min\n",
            "Epoch: 022/050 | Batch 000/461 | Cost: 0.5507\n",
            "Epoch: 022/050 | Batch 150/461 | Cost: 0.7333\n",
            "Epoch: 022/050 | Batch 300/461 | Cost: 0.4814\n",
            "Epoch: 022/050 | Batch 450/461 | Cost: 0.6881\n",
            "Epoch: 022/050\n",
            "Train ACC: 82.54 | Validation ACC: 87.30\n",
            "Time elapsed: 4.71 min\n",
            "Epoch: 023/050 | Batch 000/461 | Cost: 0.7118\n",
            "Epoch: 023/050 | Batch 150/461 | Cost: 0.6506\n",
            "Epoch: 023/050 | Batch 300/461 | Cost: 0.4700\n",
            "Epoch: 023/050 | Batch 450/461 | Cost: 0.6796\n",
            "Epoch: 023/050\n",
            "Train ACC: 82.74 | Validation ACC: 87.20\n",
            "Time elapsed: 4.93 min\n",
            "Epoch: 024/050 | Batch 000/461 | Cost: 0.5192\n",
            "Epoch: 024/050 | Batch 150/461 | Cost: 0.7034\n",
            "Epoch: 024/050 | Batch 300/461 | Cost: 0.6526\n",
            "Epoch: 024/050 | Batch 450/461 | Cost: 0.7725\n",
            "Epoch: 024/050\n",
            "Train ACC: 82.78 | Validation ACC: 87.80\n",
            "Time elapsed: 5.14 min\n",
            "Epoch: 025/050 | Batch 000/461 | Cost: 0.7274\n",
            "Epoch: 025/050 | Batch 150/461 | Cost: 0.5939\n",
            "Epoch: 025/050 | Batch 300/461 | Cost: 0.5960\n",
            "Epoch: 025/050 | Batch 450/461 | Cost: 0.6475\n",
            "Epoch: 025/050\n",
            "Train ACC: 83.00 | Validation ACC: 88.30\n",
            "Time elapsed: 5.36 min\n",
            "Epoch: 026/050 | Batch 000/461 | Cost: 0.4952\n",
            "Epoch: 026/050 | Batch 150/461 | Cost: 0.6846\n",
            "Epoch: 026/050 | Batch 300/461 | Cost: 0.6472\n",
            "Epoch: 026/050 | Batch 450/461 | Cost: 0.5632\n",
            "Epoch: 026/050\n",
            "Train ACC: 82.94 | Validation ACC: 87.40\n",
            "Time elapsed: 5.57 min\n",
            "Epoch: 027/050 | Batch 000/461 | Cost: 0.5120\n",
            "Epoch: 027/050 | Batch 150/461 | Cost: 0.6289\n",
            "Epoch: 027/050 | Batch 300/461 | Cost: 0.7053\n",
            "Epoch: 027/050 | Batch 450/461 | Cost: 0.4083\n",
            "Epoch: 027/050\n",
            "Train ACC: 83.16 | Validation ACC: 88.00\n",
            "Time elapsed: 5.78 min\n",
            "Epoch: 028/050 | Batch 000/461 | Cost: 0.5392\n",
            "Epoch: 028/050 | Batch 150/461 | Cost: 0.5031\n",
            "Epoch: 028/050 | Batch 300/461 | Cost: 0.5022\n",
            "Epoch: 028/050 | Batch 450/461 | Cost: 0.7608\n",
            "Epoch: 028/050\n",
            "Train ACC: 83.17 | Validation ACC: 88.30\n",
            "Time elapsed: 6.00 min\n",
            "Epoch: 029/050 | Batch 000/461 | Cost: 0.4466\n",
            "Epoch: 029/050 | Batch 150/461 | Cost: 0.4515\n",
            "Epoch: 029/050 | Batch 300/461 | Cost: 0.4774\n",
            "Epoch: 029/050 | Batch 450/461 | Cost: 0.4977\n",
            "Epoch: 029/050\n",
            "Train ACC: 83.15 | Validation ACC: 88.40\n",
            "Time elapsed: 6.21 min\n",
            "Epoch: 030/050 | Batch 000/461 | Cost: 0.6674\n",
            "Epoch: 030/050 | Batch 150/461 | Cost: 0.5109\n",
            "Epoch: 030/050 | Batch 300/461 | Cost: 0.5185\n",
            "Epoch: 030/050 | Batch 450/461 | Cost: 0.5111\n",
            "Epoch: 030/050\n",
            "Train ACC: 83.29 | Validation ACC: 87.80\n",
            "Time elapsed: 6.43 min\n",
            "Epoch: 031/050 | Batch 000/461 | Cost: 0.5706\n",
            "Epoch: 031/050 | Batch 150/461 | Cost: 0.5534\n",
            "Epoch: 031/050 | Batch 300/461 | Cost: 0.6899\n",
            "Epoch: 031/050 | Batch 450/461 | Cost: 0.5967\n",
            "Epoch: 031/050\n",
            "Train ACC: 83.19 | Validation ACC: 88.70\n",
            "Time elapsed: 6.64 min\n",
            "Epoch: 032/050 | Batch 000/461 | Cost: 0.5113\n",
            "Epoch: 032/050 | Batch 150/461 | Cost: 0.3937\n",
            "Epoch: 032/050 | Batch 300/461 | Cost: 0.5300\n",
            "Epoch: 032/050 | Batch 450/461 | Cost: 0.7212\n",
            "Epoch: 032/050\n",
            "Train ACC: 83.27 | Validation ACC: 88.10\n",
            "Time elapsed: 6.86 min\n",
            "Epoch: 033/050 | Batch 000/461 | Cost: 0.5365\n",
            "Epoch: 033/050 | Batch 150/461 | Cost: 0.7785\n",
            "Epoch: 033/050 | Batch 300/461 | Cost: 0.5429\n",
            "Epoch: 033/050 | Batch 450/461 | Cost: 0.5651\n",
            "Epoch: 033/050\n",
            "Train ACC: 83.16 | Validation ACC: 87.40\n",
            "Time elapsed: 7.07 min\n",
            "Epoch: 034/050 | Batch 000/461 | Cost: 0.5087\n",
            "Epoch: 034/050 | Batch 150/461 | Cost: 0.4693\n",
            "Epoch: 034/050 | Batch 300/461 | Cost: 0.4562\n",
            "Epoch: 034/050 | Batch 450/461 | Cost: 0.5066\n",
            "Epoch: 034/050\n",
            "Train ACC: 83.49 | Validation ACC: 88.50\n",
            "Time elapsed: 7.29 min\n",
            "Epoch: 035/050 | Batch 000/461 | Cost: 0.5604\n",
            "Epoch: 035/050 | Batch 150/461 | Cost: 0.4309\n",
            "Epoch: 035/050 | Batch 300/461 | Cost: 0.5422\n",
            "Epoch: 035/050 | Batch 450/461 | Cost: 0.4938\n",
            "Epoch: 035/050\n",
            "Train ACC: 83.52 | Validation ACC: 88.60\n",
            "Time elapsed: 7.50 min\n",
            "Epoch: 036/050 | Batch 000/461 | Cost: 0.7996\n",
            "Epoch: 036/050 | Batch 150/461 | Cost: 0.5987\n",
            "Epoch: 036/050 | Batch 300/461 | Cost: 0.6045\n",
            "Epoch: 036/050 | Batch 450/461 | Cost: 0.4677\n",
            "Epoch: 036/050\n",
            "Train ACC: 83.46 | Validation ACC: 88.90\n",
            "Time elapsed: 7.72 min\n",
            "Epoch: 037/050 | Batch 000/461 | Cost: 0.6250\n",
            "Epoch: 037/050 | Batch 150/461 | Cost: 0.6493\n",
            "Epoch: 037/050 | Batch 300/461 | Cost: 0.6661\n",
            "Epoch: 037/050 | Batch 450/461 | Cost: 0.4785\n",
            "Epoch: 037/050\n",
            "Train ACC: 83.54 | Validation ACC: 88.40\n",
            "Time elapsed: 7.93 min\n",
            "Epoch: 038/050 | Batch 000/461 | Cost: 0.5079\n",
            "Epoch: 038/050 | Batch 150/461 | Cost: 0.5097\n",
            "Epoch: 038/050 | Batch 300/461 | Cost: 0.6581\n",
            "Epoch: 038/050 | Batch 450/461 | Cost: 0.5454\n",
            "Epoch: 038/050\n",
            "Train ACC: 83.53 | Validation ACC: 88.10\n",
            "Time elapsed: 8.15 min\n",
            "Epoch: 039/050 | Batch 000/461 | Cost: 0.4846\n",
            "Epoch: 039/050 | Batch 150/461 | Cost: 0.5818\n",
            "Epoch: 039/050 | Batch 300/461 | Cost: 0.6538\n",
            "Epoch: 039/050 | Batch 450/461 | Cost: 0.6154\n",
            "Epoch: 039/050\n",
            "Train ACC: 83.69 | Validation ACC: 88.50\n",
            "Time elapsed: 8.37 min\n",
            "Epoch: 040/050 | Batch 000/461 | Cost: 0.6323\n",
            "Epoch: 040/050 | Batch 150/461 | Cost: 0.6033\n",
            "Epoch: 040/050 | Batch 300/461 | Cost: 0.6337\n",
            "Epoch: 040/050 | Batch 450/461 | Cost: 0.6556\n",
            "Epoch: 040/050\n",
            "Train ACC: 83.48 | Validation ACC: 88.20\n",
            "Time elapsed: 8.58 min\n",
            "Epoch: 041/050 | Batch 000/461 | Cost: 0.5198\n",
            "Epoch: 041/050 | Batch 150/461 | Cost: 0.6337\n",
            "Epoch: 041/050 | Batch 300/461 | Cost: 0.5417\n",
            "Epoch: 041/050 | Batch 450/461 | Cost: 0.7118\n",
            "Epoch: 041/050\n",
            "Train ACC: 83.75 | Validation ACC: 88.80\n",
            "Time elapsed: 8.80 min\n",
            "Epoch: 042/050 | Batch 000/461 | Cost: 0.5934\n",
            "Epoch: 042/050 | Batch 150/461 | Cost: 0.4491\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XP-b4DCHg0c3",
        "colab_type": "text"
      },
      "source": [
        "## Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5VX3-miug0c3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Trailing running average of the minibatch cost. Each 'valid' convolution\n",
        "# output averages iterations [i, i + AVG_WINDOW - 1], so it is drawn at the\n",
        "# window's last iteration to line up with the raw curve.\n",
        "AVG_WINDOW = 200\n",
        "\n",
        "plt.plot(cost_list, label='Minibatch cost')\n",
        "if len(cost_list) >= AVG_WINDOW:  # shorter lists have no full window\n",
        "    running_avg = np.convolve(cost_list,\n",
        "                              np.ones(AVG_WINDOW) / AVG_WINDOW, mode='valid')\n",
        "    plt.plot(np.arange(AVG_WINDOW - 1, len(cost_list)), running_avg,\n",
        "             label='Running average')\n",
        "\n",
        "plt.ylabel('Cross Entropy')\n",
        "plt.xlabel('Iteration')\n",
        "plt.legend()\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Cb5xqLRXg0c5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Size the x-axis by the epochs actually recorded (training may have been\n",
        "# stopped early) and convert entries with float() so device tensors plot too.\n",
        "epochs = np.arange(1, len(train_acc_list) + 1)\n",
        "plt.plot(epochs, [float(acc) for acc in train_acc_list], label='Training')\n",
        "plt.plot(epochs, [float(acc) for acc in valid_acc_list], label='Validation')\n",
        "\n",
        "plt.xlabel('Epoch')\n",
        "plt.ylabel('Accuracy')\n",
        "plt.legend()\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HisgxVvDg0c8",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Switch to inference mode explicitly instead of relying on the mode the\n",
        "# training cell left behind: an interrupted training run leaves the model in\n",
        "# train mode, which changes layers such as dropout/batch norm (if present).\n",
        "model.eval()\n",
        "with torch.set_grad_enabled(False):\n",
        "    test_acc = compute_acc(model=model,\n",
        "                           data_loader=test_loader,\n",
        "                           device=DEVICE)\n",
        "\n",
        "    valid_acc = compute_acc(model=model,\n",
        "                            data_loader=valid_loader,\n",
        "                            device=DEVICE)\n",
        "\n",
        "print(f'Validation ACC: {valid_acc:.2f}%')\n",
        "print(f'Test ACC: {test_acc:.2f}%')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "S2fF6D2Ug0c-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "%watermark -iv"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}